diff --git a/.rat-excludes b/.rat-excludes
index a4f316a4aaa04..8b5061415ff4c 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -25,6 +25,16 @@ graphlib-dot.min.js
sorttable.js
vis.min.js
vis.min.css
+dataTables.bootstrap.css
+dataTables.bootstrap.min.js
+dataTables.rowsGroup.js
+jquery.blockUI.min.js
+jquery.cookies.2.2.0.min.js
+jquery.dataTables.1.10.4.min.css
+jquery.dataTables.1.10.4.min.js
+jquery.mustache.js
+jsonFormatter.min.css
+jsonFormatter.min.js
.*avsc
.*txt
.*json
@@ -63,12 +73,12 @@ logs
.*dependency-reduced-pom.xml
known_translations
json_expectation
-local-1422981759269/*
-local-1422981780767/*
-local-1425081759269/*
-local-1426533911241/*
-local-1426633911242/*
-local-1430917381534/*
+local-1422981759269
+local-1422981780767
+local-1425081759269
+local-1426533911241
+local-1426633911242
+local-1430917381534
local-1430917381535_1
local-1430917381535_2
DESCRIPTION
diff --git a/LICENSE b/LICENSE
index 9c944ac610afe..9fc29db8d3f22 100644
--- a/LICENSE
+++ b/LICENSE
@@ -291,3 +291,9 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3)
(MIT License) sorttable (https://github.com/stuartlangridge/sorttable)
(MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE)
+ (MIT License) datatables (http://datatables.net/license)
+ (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE)
+ (MIT License) cookies (http://code.google.com/p/cookies/wiki/License)
+ (MIT License) blockUI (http://jquery.malsup.com/block/)
+ (MIT License) RowsGroup (http://datatables.net/license/mit)
+ (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html)
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 6c79f9189787d..477d4931c3a88 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -20,13 +20,13 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
     <version>2.0.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

   <groupId>org.apache.spark</groupId>
-  <artifactId>spark-assembly_2.10</artifactId>
+  <artifactId>spark-assembly_2.11</artifactId>
   <name>Spark Project Assembly</name>
   <url>http://spark.apache.org/</url>
   <packaging>pom</packaging>
diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
index 67723fa421ab1..442043cb51164 100644
--- a/common/sketch/pom.xml
+++ b/common/sketch/pom.xml
@@ -21,13 +21,13 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
     <version>2.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>

   <groupId>org.apache.spark</groupId>
-  <artifactId>spark-sketch_2.10</artifactId>
+  <artifactId>spark-sketch_2.11</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Sketch</name>
   <url>http://spark.apache.org/</url>
@@ -35,6 +35,13 @@
     <sbt.project.name>sketch</sbt.project.name>
   </properties>

+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-test-tags_${scala.binary.version}</artifactId>
+    </dependency>
+  </dependencies>
+
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
index 2a0484e324b13..480a0a79db32d 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
@@ -22,7 +22,7 @@
import java.io.IOException;
import java.util.Arrays;
-public final class BitArray {
+final class BitArray {
private final long[] data;
private long bitCount;
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 81772fcea0ec2..c0b425e729595 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -22,16 +22,10 @@
import java.io.OutputStream;
/**
- * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether
- * an element is a member of a set. It returns false when the element is definitely not in the
- * set, returns true when the element is probably in the set.
- *
- * Internally a Bloom filter is initialized with 2 information: how many space to use(number of
- * bits) and how many hash values to calculate for each record. To get as lower false positive
- * probability as possible, user should call {@link BloomFilter#create} to automatically pick a
- * best combination of these 2 parameters.
- *
- * Currently the following data types are supported:
+ * A Bloom filter is a space-efficient probabilistic data structure that offers an approximate
+ * containment test with one-sided error: if it claims that an item is contained in it, this
+ * might be in error, but if it claims that an item is not contained in it, then this is
+ * definitely true. Currently supported data types include:
 * <ul>
 *   <li>{@link Byte}</li>
 *   <li>{@link Short}</li>
@@ -39,14 +33,17 @@
 *   <li>{@link Long}</li>
 *   <li>{@link String}</li>
 * </ul>
+ * The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that
+ * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that has
+ * not actually been put in the {@code BloomFilter}.
*
- * The implementation is largely based on the {@code BloomFilter} class from guava.
+ * The implementation is largely based on the {@code BloomFilter} class from Guava.
*/
public abstract class BloomFilter {
public enum Version {
/**
- * {@code BloomFilter} binary format version 1 (all values written in big-endian order):
+ * {@code BloomFilter} binary format version 1. All values written in big-endian order:
 * <ul>
 *   <li>Version number, always 1 (32 bit)</li>
 *   <li>Number of hash functions (32 bit)</li>
@@ -68,14 +65,13 @@ int getVersionNumber() {
}
/**
- * Returns the false positive probability, i.e. the probability that
- * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that
- * has not actually been put in the {@code BloomFilter}.
+ * Returns the probability that {@linkplain #mightContain(Object)} erroneously returns {@code true}
+ * for an object that has not actually been put in the {@code BloomFilter}.
*
- * <p>Ideally, this number should be close to the {@code fpp} parameter
- * passed in to create this bloom filter, or smaller. If it is
- * significantly higher, it is usually the case that too many elements (more than
- * expected) have been put in the {@code BloomFilter}, degenerating it.
+ * Ideally, this number should be close to the {@code fpp} parameter passed in
+ * {@linkplain #create(long, double)}, or smaller. If it is significantly higher, it is usually
+ * the case that too many items (more than expected) have been put in the {@code BloomFilter},
+ * degenerating it.
*/
public abstract double expectedFpp();
@@ -85,8 +81,8 @@ int getVersionNumber() {
public abstract long bitSize();
/**
- * Puts an element into this {@code BloomFilter}. Ensures that subsequent invocations of
- * {@link #mightContain(Object)} with the same element will always return {@code true}.
+ * Puts an item into this {@code BloomFilter}. Ensures that subsequent invocations of
+ * {@linkplain #mightContain(Object)} with the same item will always return {@code true}.
*
* @return true if the bloom filter's bits changed as a result of this operation. If the bits
* changed, this is definitely the first time {@code object} has been added to the
@@ -98,19 +94,19 @@ int getVersionNumber() {
public abstract boolean put(Object item);
/**
- * A specialized variant of {@link #put(Object)}, that can only be used to put utf-8 string.
+ * A specialized variant of {@link #put(Object)} that only supports {@code String} items.
*/
- public abstract boolean putString(String str);
+ public abstract boolean putString(String item);
/**
- * A specialized variant of {@link #put(Object)}, that can only be used to put long.
+ * A specialized variant of {@link #put(Object)} that only supports {@code long} items.
*/
- public abstract boolean putLong(long l);
+ public abstract boolean putLong(long item);
/**
- * A specialized variant of {@link #put(Object)}, that can only be used to put byte array.
+ * A specialized variant of {@link #put(Object)} that only supports byte array items.
*/
- public abstract boolean putBinary(byte[] bytes);
+ public abstract boolean putBinary(byte[] item);
/**
* Determines whether a given bloom filter is compatible with this bloom filter. For two
@@ -137,38 +133,36 @@ int getVersionNumber() {
public abstract boolean mightContain(Object item);
/**
- * A specialized variant of {@link #mightContain(Object)}, that can only be used to test utf-8
- * string.
+ * A specialized variant of {@link #mightContain(Object)} that only tests {@code String} items.
*/
- public abstract boolean mightContainString(String str);
+ public abstract boolean mightContainString(String item);
/**
- * A specialized variant of {@link #mightContain(Object)}, that can only be used to test long.
+ * A specialized variant of {@link #mightContain(Object)} that only tests {@code long} items.
*/
- public abstract boolean mightContainLong(long l);
+ public abstract boolean mightContainLong(long item);
/**
- * A specialized variant of {@link #mightContain(Object)}, that can only be used to test byte
- * array.
+ * A specialized variant of {@link #mightContain(Object)} that only tests byte array items.
*/
- public abstract boolean mightContainBinary(byte[] bytes);
+ public abstract boolean mightContainBinary(byte[] item);
/**
- * Writes out this {@link BloomFilter} to an output stream in binary format.
- * It is the caller's responsibility to close the stream.
+ * Writes out this {@link BloomFilter} to an output stream in binary format. It is the caller's
+ * responsibility to close the stream.
*/
public abstract void writeTo(OutputStream out) throws IOException;
/**
- * Reads in a {@link BloomFilter} from an input stream.
- * It is the caller's responsibility to close the stream.
+ * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close
+ * the stream.
*/
public static BloomFilter readFrom(InputStream in) throws IOException {
return BloomFilterImpl.readFrom(in);
}
/**
- * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the
+ * Computes the optimal k (number of hashes per item inserted in Bloom filter), given the
* expected insertions and total number of bits in the Bloom filter.
*
* See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula.
@@ -197,21 +191,31 @@ private static long optimalNumOfBits(long n, double p) {
static final double DEFAULT_FPP = 0.03;
/**
- * Creates a {@link BloomFilter} with given {@code expectedNumItems} and the default {@code fpp}.
+ * Creates a {@link BloomFilter} with the expected number of insertions and a default expected
+ * false positive probability of 3%.
+ *
+ * Note that overflowing a {@code BloomFilter} with significantly more elements than specified
+ * will result in its saturation, and a sharp deterioration of its false positive probability.
*/
public static BloomFilter create(long expectedNumItems) {
return create(expectedNumItems, DEFAULT_FPP);
}
/**
- * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code fpp}, it will pick
- * an optimal {@code numBits} and {@code numHashFunctions} for the bloom filter.
+ * Creates a {@link BloomFilter} with the expected number of insertions and expected false
+ * positive probability.
+ *
+ * Note that overflowing a {@code BloomFilter} with significantly more elements than specified
+ * will result in its saturation, and a sharp deterioration of its false positive probability.
*/
public static BloomFilter create(long expectedNumItems, double fpp) {
- assert fpp > 0.0 : "False positive probability must be > 0.0";
- assert fpp < 1.0 : "False positive probability must be < 1.0";
- long numBits = optimalNumOfBits(expectedNumItems, fpp);
- return create(expectedNumItems, numBits);
+ if (fpp <= 0D || fpp >= 1D) {
+ throw new IllegalArgumentException(
+ "False positive probability must be within range (0.0, 1.0)"
+ );
+ }
+
+ return create(expectedNumItems, optimalNumOfBits(expectedNumItems, fpp));
}
/**
@@ -219,9 +223,14 @@ public static BloomFilter create(long expectedNumItems, double fpp) {
* pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter.
*/
public static BloomFilter create(long expectedNumItems, long numBits) {
- assert expectedNumItems > 0 : "Expected insertions must be > 0";
- assert numBits > 0 : "number of bits must be > 0";
- int numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits);
- return new BloomFilterImpl(numHashFunctions, numBits);
+ if (expectedNumItems <= 0) {
+ throw new IllegalArgumentException("Expected insertions must be positive");
+ }
+
+ if (numBits <= 0) {
+ throw new IllegalArgumentException("Number of bits must be positive");
+ }
+
+ return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits);
}
}
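The new javadoc above defines the FPP contract and points at the optimal-parameter formulas. A brief, hedged usage sketch in Scala, assuming only the API visible in this diff (create, putLong, mightContainLong, expectedFpp); the sizing formulas are noted as comments since their exact form is not quoted in these hunks:

    import org.apache.spark.util.sketch.BloomFilter

    // Build a filter for ~10,000 insertions with a 1% target false positive probability.
    val filter = BloomFilter.create(10000L, 0.01)
    (1L to 10000L).foreach(i => filter.putLong(i))

    filter.mightContainLong(42L)      // true: 42 was definitely inserted
    filter.mightContainLong(123456L)  // usually false; true only with roughly 1% probability
    filter.expectedFpp()              // should stay close to (or below) 0.01

    // Standard Bloom filter formulas assumed to back optimalNumOfBits / optimalNumOfHashFunctions:
    //   numBits          m = -n * ln(p) / (ln 2)^2
    //   numHashFunctions k = max(1, round(m / n * ln 2))
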
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
index 35107e0b389d7..92c28bcb56a5a 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -19,9 +19,10 @@
import java.io.*;
-public class BloomFilterImpl extends BloomFilter implements Serializable {
+class BloomFilterImpl extends BloomFilter implements Serializable {
private int numHashFunctions;
+
private BitArray bits;
BloomFilterImpl(int numHashFunctions, long numBits) {
@@ -77,14 +78,14 @@ public boolean put(Object item) {
}
@Override
- public boolean putString(String str) {
- return putBinary(Utils.getBytesFromUTF8String(str));
+ public boolean putString(String item) {
+ return putBinary(Utils.getBytesFromUTF8String(item));
}
@Override
- public boolean putBinary(byte[] bytes) {
- int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
- int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
+ public boolean putBinary(byte[] item) {
+ int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0);
+ int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1);
long bitSize = bits.bitSize();
boolean bitsChanged = false;
@@ -100,14 +101,14 @@ public boolean putBinary(byte[] bytes) {
}
@Override
- public boolean mightContainString(String str) {
- return mightContainBinary(Utils.getBytesFromUTF8String(str));
+ public boolean mightContainString(String item) {
+ return mightContainBinary(Utils.getBytesFromUTF8String(item));
}
@Override
- public boolean mightContainBinary(byte[] bytes) {
- int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
- int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
+ public boolean mightContainBinary(byte[] item) {
+ int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0);
+ int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1);
long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
@@ -124,14 +125,14 @@ public boolean mightContainBinary(byte[] bytes) {
}
@Override
- public boolean putLong(long l) {
+ public boolean putLong(long item) {
// Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n
// hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
    // Note that `CountMinSketch` uses a different strategy: it hashes the input long element with
    // every i to produce n hash values.
// TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
- int h1 = Murmur3_x86_32.hashLong(l, 0);
- int h2 = Murmur3_x86_32.hashLong(l, h1);
+ int h1 = Murmur3_x86_32.hashLong(item, 0);
+ int h2 = Murmur3_x86_32.hashLong(item, h1);
long bitSize = bits.bitSize();
boolean bitsChanged = false;
@@ -147,9 +148,9 @@ public boolean putLong(long l) {
}
@Override
- public boolean mightContainLong(long l) {
- int h1 = Murmur3_x86_32.hashLong(l, 0);
- int h2 = Murmur3_x86_32.hashLong(l, h1);
+ public boolean mightContainLong(long item) {
+ int h1 = Murmur3_x86_32.hashLong(item, 0);
+ int h2 = Murmur3_x86_32.hashLong(item, h1);
long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
@@ -197,7 +198,7 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep
throw new IncompatibleMergeException("Cannot merge null bloom filter");
}
- if (!(other instanceof BloomFilter)) {
+ if (!(other instanceof BloomFilterImpl)) {
throw new IncompatibleMergeException(
"Cannot merge bloom filter of class " + other.getClass().getName()
);
@@ -211,7 +212,8 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep
if (this.numHashFunctions != that.numHashFunctions) {
throw new IncompatibleMergeException(
- "Cannot merge bloom filters with different number of hash functions");
+ "Cannot merge bloom filters with different number of hash functions"
+ );
}
this.bits.putAll(that.bits);
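The putLong comment above describes the double-hashing scheme: two Murmur3 hashes h1 and h2 yield numHashFunctions probe positions as h1 + i * h2. A small Scala sketch of that index derivation; the negative-hash flip is an assumption mirroring the usual Guava-style treatment, since that part of the loop body is not shown in these hunks:

    def bitIndexes(h1: Int, h2: Int, numHashFunctions: Int, bitSize: Long): Seq[Long] =
      (1 to numHashFunctions).map { i =>
        var combined = h1 + i * h2
        if (combined < 0) combined = ~combined  // assumed: flip to non-negative before the modulo
        combined % bitSize                      // bit to set (put*) or test (mightContain*)
      }
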
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 5692e574d4c7e..48f98680f48ca 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -22,7 +22,7 @@
import java.io.OutputStream;
/**
- * A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in
+ * A Count-min sketch is a probabilistic data structure used for summarizing streams of data in
* sub-linear space. Currently, supported data types include:
 * <ul>
 *   <li>{@link Byte}</li>
@@ -31,8 +31,7 @@
 *   <li>{@link Long}</li>
 *   <li>{@link String}</li>
 * </ul>
- * Each {@link CountMinSketch} is initialized with a random seed, and a pair
- * of parameters:
+ * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
*
*
*
- * See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details,
- * including proofs of the estimates and error bounds used in this implementation.
- *
* This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
*/
abstract public class CountMinSketch {
public enum Version {
/**
- * {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
+ * {@code CountMinSketch} binary format version 1. All values written in big-endian order:
*
- }
-
- private def appRow(info: ApplicationHistoryInfo): Seq[Node] = {
- attemptRow(false, info, info.attempts.head, true)
- }
-
- private def appWithAttemptRow(info: ApplicationHistoryInfo): Seq[Node] = {
- attemptRow(true, info, info.attempts.head, true) ++
- info.attempts.drop(1).flatMap(attemptRow(true, info, _, false))
- }
-
- private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = {
- UIUtils.prependBaseUri("/?" + Array(
- "page=" + linkPage,
- "showIncomplete=" + showIncomplete
- ).mkString("&"))
+ private def makePageLink(showIncomplete: Boolean): String = {
+ UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index 7e2cf956c7253..4ffb5283e99a4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -65,7 +65,7 @@ private[spark] class ApplicationInfo(
appSource = new ApplicationSource(this)
nextExecutorId = 0
removedExecutors = new ArrayBuffer[ExecutorDesc]
- executorLimit = Integer.MAX_VALUE
+ executorLimit = desc.initialExecutorLimit.getOrElse(Integer.MAX_VALUE)
appUIUrlAtHistoryServer = None
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 202a1b787c21b..0f11f680b3914 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -74,7 +74,7 @@ private[deploy] class Master(
val workers = new HashSet[WorkerInfo]
val idToApp = new HashMap[String, ApplicationInfo]
- val waitingApps = new ArrayBuffer[ApplicationInfo]
+ private val waitingApps = new ArrayBuffer[ApplicationInfo]
val apps = new HashSet[ApplicationInfo]
private val idToWorker = new HashMap[String, WorkerInfo]
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
index 66a9ff38678c6..39b2647a900f0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
@@ -42,6 +42,6 @@ private[spark] class MasterSource(val master: Master) extends Source {
// Gauge for waiting application numbers in cluster
metricRegistry.register(MetricRegistry.name("waitingApps"), new Gauge[Int] {
- override def getValue: Int = master.waitingApps.size
+ override def getValue: Int = master.apps.filter(_.state == ApplicationState.WAITING).size
})
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala
index 66e1e645007a7..9b31497adfb12 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala
@@ -50,7 +50,7 @@ private[mesos] class MesosClusterDispatcher(
extends Logging {
private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host)
- private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase()
+ private val recoveryMode = conf.get("spark.deploy.recoveryMode", "NONE").toUpperCase()
logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode)
private val engineFactory = recoveryMode match {
@@ -98,8 +98,8 @@ private[mesos] object MesosClusterDispatcher extends Logging {
conf.setMaster(dispatcherArgs.masterUrl)
conf.setAppName(dispatcherArgs.name)
dispatcherArgs.zookeeperUrl.foreach { z =>
- conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER")
- conf.set("spark.mesos.deploy.zookeeper.url", z)
+ conf.set("spark.deploy.recoveryMode", "ZOOKEEPER")
+ conf.set("spark.deploy.zookeeper.url", z)
}
val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf)
dispatcher.start()
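For reference, a hedged Scala sketch of what the renamed keys amount to when a ZooKeeper URL is supplied to the dispatcher; the quorum address is a hypothetical example:

    val conf = new org.apache.spark.SparkConf()
      .set("spark.deploy.recoveryMode", "ZOOKEEPER")
      .set("spark.deploy.zookeeper.url", "zk://zk1:2181,zk2:2181")  // hypothetical quorum
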
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 179d3b9f20b1f..df3c286a0a66f 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -394,7 +394,7 @@ private[deploy] class Worker(
// rpcEndpoint.
// Copy ids so that it can be used in the cleanup thread.
val appIds = executors.values.map(_.appId).toSet
- val cleanupFuture = concurrent.future {
+ val cleanupFuture = concurrent.Future {
val appDirs = workDir.listFiles()
if (appDirs == null) {
throw new IOException("ERROR: Failed to list files in " + appDirs)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 136cf4a84d387..3b5cb18da1b26 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -19,6 +19,7 @@ package org.apache.spark.executor
import java.net.URL
import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.util.{Failure, Success}
@@ -42,6 +43,7 @@ private[spark] class CoarseGrainedExecutorBackend(
env: SparkEnv)
extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
+ private[this] val stopping = new AtomicBoolean(false)
var executor: Executor = null
@volatile var driver: Option[RpcEndpointRef] = None
@@ -102,19 +104,23 @@ private[spark] class CoarseGrainedExecutorBackend(
}
case StopExecutor =>
+ stopping.set(true)
logInfo("Driver commanded a shutdown")
// Cannot shutdown here because an ack may need to be sent back to the caller. So send
// a message to self to actually do the shutdown.
self.send(Shutdown)
case Shutdown =>
+ stopping.set(true)
executor.stop()
stop()
rpcEnv.shutdown()
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- if (driver.exists(_.address == remoteAddress)) {
+ if (stopping.get()) {
+ logInfo(s"Driver from $remoteAddress disconnected during shutdown")
+ } else if (driver.exists(_.address == remoteAddress)) {
logError(s"Driver $remoteAddress disassociated! Shutting down.")
System.exit(1)
} else {
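A minimal Scala sketch (toy types, not Spark's RpcEndpoint API) of the shutdown pattern introduced here: flag the stop first, defer the real teardown to a self-message so the ack can still be delivered, and treat a driver disconnect during shutdown as expected rather than fatal:

    import java.util.concurrent.atomic.AtomicBoolean

    class ToyBackend(sendToSelf: String => Unit) {
      private val stopping = new AtomicBoolean(false)

      def onStopExecutor(): Unit = {
        stopping.set(true)      // remember that a shutdown was requested
        sendToSelf("Shutdown")  // tear down later, after the ack has gone back to the caller
      }

      def onDisconnected(fromDriver: Boolean): Unit = {
        if (stopping.get()) ()            // expected disconnect during shutdown: ignore
        else if (fromDriver) sys.exit(1)  // losing the driver unexpectedly is still fatal
      }
    }
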
diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
index ed9e157ce758b..6d30d3c76a9fb 100644
--- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
@@ -81,10 +81,15 @@ class InputMetrics private (
*/
def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue)
+  // Once incBytesRead & incRecordsRead are ready to be removed from the public API
+ // we can remove the internal versions and make the previous public API private.
+ // This has been done to suppress warnings when building.
@deprecated("incrementing input metrics is for internal use only", "2.0.0")
def incBytesRead(v: Long): Unit = _bytesRead.add(v)
+ private[spark] def incBytesReadInternal(v: Long): Unit = _bytesRead.add(v)
@deprecated("incrementing input metrics is for internal use only", "2.0.0")
def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
+ private[spark] def incRecordsReadInternal(v: Long): Unit = _recordsRead.add(v)
private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v)
private[spark] def setReadMethod(v: DataReadMethod.Value): Unit =
_readMethod.setValue(v.toString)
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 822bad9f8f52f..8ff0620f837c9 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -337,8 +337,8 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se
* field is always empty, since this represents the partial updates recorded in this task,
* not the aggregated value across multiple tasks.
*/
- def accumulatorUpdates(): Seq[AccumulableInfo] = accums.map { a =>
- new AccumulableInfo(a.id, a.name, Some(a.localValue), None, a.isInternal, a.countFailedValues)
+ def accumulatorUpdates(): Seq[AccumulableInfo] = {
+ accums.map { a => a.toInfo(Some(a.localValue), None) }
}
// If we are reconstructing this TaskMetrics on the driver, some metrics may already be set.
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index c07f346bbafd5..bd61d04d42f05 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -103,7 +103,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
   * If the RDD contains infinity or NaN, this method throws an exception.
   * If the elements in the RDD do not vary (max == min), it always returns a single bucket.
*/
- def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = self.withScope {
+ def histogram(bucketCount: Int): (Array[Double], Array[Long]) = self.withScope {
// Scala's built-in range has issues. See #SI-8782
def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = {
val span = max - min
@@ -112,7 +112,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
// Compute the minimum and the maximum
val (max: Double, min: Double) = self.mapPartitions { items =>
Iterator(items.foldRight(Double.NegativeInfinity,
- Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) =>
+ Double.PositiveInfinity)((e: Double, x: (Double, Double)) =>
(x._1.max(e), x._2.min(e))))
}.reduce { (maxmin1, maxmin2) =>
(maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 3204e6adceca2..805cd9fe1f638 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -215,6 +215,7 @@ class HadoopRDD[K, V](
// TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD
val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
+ val existingBytesRead = inputMetrics.bytesRead
// Sets the thread local variable for the file's name
split.inputSplit.value match {
@@ -230,9 +231,13 @@ class HadoopRDD[K, V](
case _ => None
}
+ // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics.
+ // If we do a coalesce, however, we are likely to compute multiple partitions in the same
+    // task and in the same thread, in which case we need to avoid overriding values written by
+ // previous partitions (SPARK-13071).
def updateBytesRead(): Unit = {
getBytesReadCallback.foreach { getBytesRead =>
- inputMetrics.setBytesRead(getBytesRead())
+ inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
}
}
@@ -255,7 +260,7 @@ class HadoopRDD[K, V](
finished = true
}
if (!finished) {
- inputMetrics.incRecordsRead(1)
+ inputMetrics.incRecordsReadInternal(1)
}
if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
updateBytesRead()
@@ -287,7 +292,7 @@ class HadoopRDD[K, V](
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
- inputMetrics.incBytesRead(split.inputSplit.value.getLength)
+ inputMetrics.incBytesReadInternal(split.inputSplit.value.getLength)
} catch {
case e: java.io.IOException =>
logWarning("Unable to get input size to set InputMetrics for task", e)
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 4d2816e335fe3..f23da39eb90de 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -130,6 +130,7 @@ class NewHadoopRDD[K, V](
val conf = getConf
val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
+ val existingBytesRead = inputMetrics.bytesRead
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
@@ -139,9 +140,13 @@ class NewHadoopRDD[K, V](
case _ => None
}
+ // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics.
+ // If we do a coalesce, however, we are likely to compute multiple partitions in the same
+    // task and in the same thread, in which case we need to avoid overriding values written by
+ // previous partitions (SPARK-13071).
def updateBytesRead(): Unit = {
getBytesReadCallback.foreach { getBytesRead =>
- inputMetrics.setBytesRead(getBytesRead())
+ inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
}
}
@@ -183,7 +188,7 @@ class NewHadoopRDD[K, V](
}
havePair = false
if (!finished) {
- inputMetrics.incRecordsRead(1)
+ inputMetrics.incRecordsReadInternal(1)
}
if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
updateBytesRead()
@@ -214,7 +219,7 @@ class NewHadoopRDD[K, V](
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
- inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength)
+ inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength)
} catch {
case e: java.io.IOException =>
logWarning("Unable to get input size to set InputMetrics for task", e)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 33f2f0b44f773..61905a8421124 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -726,6 +726,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*
* Warning: this doesn't return a multimap (so if you have multiple values to the same key, only
* one value per key is preserved in the map returned)
+ *
+ * @note this method should only be used if the resulting data is expected to be small, as
+ * all the data is loaded into the driver's memory.
*/
def collectAsMap(): Map[K, V] = self.withScope {
val data = self.collect()
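A short Scala usage sketch of the caution added above, assuming an existing SparkContext sc; which duplicate value survives is unspecified:

    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
    val small = pairs.collectAsMap()  // entire result lands in driver memory, one value per key
    // small == Map("a" -> 1, "b" -> 3) or Map("a" -> 2, "b" -> 3)
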
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index be47172581b7f..a81a98b526b5a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -481,6 +481,9 @@ abstract class RDD[T: ClassTag](
/**
* Return a fixed-size sampled subset of this RDD in an array
*
+ * @note this method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
+ *
* @param withReplacement whether sampling is done with replacement
* @param num size of the returned sample
* @param seed seed for the random number generator
@@ -836,6 +839,9 @@ abstract class RDD[T: ClassTag](
/**
* Return an array that contains all of the elements in this RDD.
+ *
+ * @note this method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
*/
def collect(): Array[T] = withScope {
val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
@@ -1202,6 +1208,9 @@ abstract class RDD[T: ClassTag](
* results from that partition to estimate the number of additional partitions needed to satisfy
* the limit.
*
+ * @note this method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
+ *
* @note due to complications in the internal implementation, this method will raise
* an exception if called on an RDD of `Nothing` or `Null`.
*/
@@ -1263,6 +1272,9 @@ abstract class RDD[T: ClassTag](
* // returns Array(6, 5)
* }}}
*
+ * @note this method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
+ *
* @param num k, the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
@@ -1283,6 +1295,9 @@ abstract class RDD[T: ClassTag](
* // returns Array(2, 3)
* }}}
*
+ * @note this method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
+ *
* @param num k, the number of elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
@@ -1542,6 +1557,15 @@ abstract class RDD[T: ClassTag](
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
+ // Whether to checkpoint all ancestor RDDs that are marked for checkpointing. By default,
+ // we stop as soon as we find the first such RDD, an optimization that allows us to write
+ // less data but is not safe for all workloads. E.g. in streaming we may checkpoint both
+ // an RDD and its parent in every batch, in which case the parent may never be checkpointed
+ // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847).
+ private val checkpointAllMarkedAncestors =
+ Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS))
+ .map(_.toBoolean).getOrElse(false)
+
/** Returns the first parent RDD */
protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
@@ -1585,6 +1609,13 @@ abstract class RDD[T: ClassTag](
if (!doCheckpointCalled) {
doCheckpointCalled = true
if (checkpointData.isDefined) {
+ if (checkpointAllMarkedAncestors) {
+        // TODO We can collect all the RDDs that need to be checkpointed, and then checkpoint
+ // them in parallel.
+ // Checkpoint parents first because our lineage will be truncated after we
+ // checkpoint ourselves
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
checkpointData.get.checkpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
@@ -1704,6 +1735,9 @@ abstract class RDD[T: ClassTag](
*/
object RDD {
+ private[spark] val CHECKPOINT_ALL_MARKED_ANCESTORS =
+ "spark.checkpoint.checkpointAllMarkedAncestors"
+
// The following implicit functions were in SparkContext before 1.3 and users had to
// `import SparkContext._` to enable them. Now we move them here to make the compiler find
// them automatically. However, we still keep the old functions in SparkContext for backward
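A hedged Scala sketch of the new flag in use, assuming an existing SparkContext sc: with the local property set, checkpointing the child also checkpoints the marked parent instead of stopping at the first marked RDD.

    sc.setCheckpointDir("/tmp/checkpoints")  // any writable directory (HDFS in production)
    sc.setLocalProperty("spark.checkpoint.checkpointAllMarkedAncestors", "true")

    val parent = sc.parallelize(1 to 100)
    val child = parent.map(_ * 2)
    parent.checkpoint()
    child.checkpoint()
    child.count()  // materializes both checkpoints, so the parent's lineage is truncated as well
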
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index 9d45fff9213c6..cedacad44afec 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -35,6 +35,7 @@ import org.apache.spark.annotation.DeveloperApi
* @param value total accumulated value so far, maybe None if used on executors to describe a task
* @param internal whether this accumulator was internal
* @param countFailedValues whether to count this accumulator's partial value if the task failed
+ * @param metadata internal metadata associated with this accumulator, if any
*/
@DeveloperApi
case class AccumulableInfo private[spark] (
@@ -43,7 +44,9 @@ case class AccumulableInfo private[spark] (
update: Option[Any], // represents a partial update within a task
value: Option[Any],
private[spark] val internal: Boolean,
- private[spark] val countFailedValues: Boolean)
+ private[spark] val countFailedValues: Boolean,
+ // TODO: use this to identify internal task metrics instead of encoding it in the name
+ private[spark] val metadata: Option[String] = None)
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 897479b50010d..ee0b8a1c95fd8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1101,11 +1101,8 @@ class DAGScheduler(
acc ++= partialValue
// To avoid UI cruft, ignore cases where value wasn't updated
if (acc.name.isDefined && partialValue != acc.zero) {
- val name = acc.name
- stage.latestInfo.accumulables(id) = new AccumulableInfo(
- id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues)
- event.taskInfo.accumulables += new AccumulableInfo(
- id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues)
+ stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value))
+ event.taskInfo.accumulables += acc.toInfo(Some(partialValue), Some(acc.value))
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 36f2b74f948f1..01fee46e73a80 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -232,8 +232,6 @@ private[spark] object EventLoggingListener extends Logging {
// Suffix applied to the names of files still being written by applications.
val IN_PROGRESS = ".inprogress"
val DEFAULT_LOG_DIR = "/tmp/spark-events"
- val SPARK_VERSION_KEY = "SPARK_VERSION"
- val COMPRESSION_CODEC_KEY = "COMPRESSION_CODEC"
private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 28974740e91d5..0a45ef5283326 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -271,7 +271,7 @@ class StatsReportListener extends SparkListener with Logging {
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
implicit val sc = stageCompleted
- this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+ this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}")
showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics)
// Shuffle write
@@ -298,6 +298,17 @@ class StatsReportListener extends SparkListener with Logging {
taskInfoMetrics.clear()
}
+ private def getStatusDetail(info: StageInfo): String = {
+ val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("")
+ val timeTaken = info.submissionTime.map(
+ x => info.completionTime.getOrElse(System.currentTimeMillis()) - x
+ ).getOrElse("-")
+
+ s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " +
+ s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " +
+ s"Took: $timeTaken msec"
+ }
+
}
private[spark] object StatsReportListener extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 16f33163789ab..d209645610c12 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -19,11 +19,11 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.Semaphore
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
+import org.apache.spark.{Logging, SparkConf, SparkContext}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
-import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress}
+import org.apache.spark.rpc.RpcEndpointAddress
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
@@ -89,8 +89,16 @@ private[spark] class SparkDeploySchedulerBackend(
args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts)
val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt)
- val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory,
- command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor)
+ // If we're using dynamic allocation, set our initial executor limit to 0 for now.
+ // ExecutorAllocationManager will send the real initial limit to the Master later.
+ val initialExecutorLimit =
+ if (Utils.isDynamicAllocationEnabled(conf)) {
+ Some(0)
+ } else {
+ None
+ }
+ val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
+ appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit)
client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf)
client.start()
launcherBackend.setState(SparkAppHandle.State.SUBMITTED)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 58c30e7d97886..0a2d72f4dcb4b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -19,11 +19,13 @@ package org.apache.spark.scheduler.cluster.mesos
import java.io.File
import java.util.{Collections, List => JList}
+import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import scala.collection.JavaConverters._
import scala.collection.mutable.{HashMap, HashSet}
+import com.google.common.base.Stopwatch
import com.google.common.collect.HashBiMap
import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver}
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
@@ -60,6 +62,13 @@ private[spark] class CoarseMesosSchedulerBackend(
// Maximum number of cores to acquire (TODO: we'll need more flexible controls here)
val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt
+ private[this] val shutdownTimeoutMS =
+ conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s")
+ .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0")
+
+ // Synchronization protected by stateLock
+ private[this] var stopCalled: Boolean = false
+
// If shuffle service is enabled, the Spark driver will register with the shuffle service.
// This is for cleaning up shuffle files reliably.
private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
@@ -78,10 +87,17 @@ private[spark] class CoarseMesosSchedulerBackend(
val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int]
/**
- * The total number of executors we aim to have. Undefined when not using dynamic allocation
- * and before the ExecutorAllocatorManager calls [[doRequestTotalExecutors]].
+ * The total number of executors we aim to have. Undefined when not using dynamic allocation.
+ * Initially set to 0 when using dynamic allocation; the executor allocation manager will send
+ * the real initial limit later.
*/
- private var executorLimitOption: Option[Int] = None
+ private var executorLimitOption: Option[Int] = {
+ if (Utils.isDynamicAllocationEnabled(conf)) {
+ Some(0)
+ } else {
+ None
+ }
+ }
/**
* Return the current executor limit, which may be [[Int.MaxValue]]
@@ -179,7 +195,7 @@ private[spark] class CoarseMesosSchedulerBackend(
.orElse(Option(System.getenv("SPARK_EXECUTOR_URI")))
if (uri.isEmpty) {
- val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath
+ val runScript = new File(executorSparkHome, "./bin/spark-class").getPath
command.setValue(
"%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend"
.format(prefixEnv, runScript) +
@@ -245,6 +261,13 @@ private[spark] class CoarseMesosSchedulerBackend(
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
stateLock.synchronized {
+ if (stopCalled) {
+ logDebug("Ignoring offers during shutdown")
+ // Driver should simply return a stopped status on race
+ // condition between this.stop() and completing here
+ offers.asScala.map(_.getId).foreach(d.declineOffer)
+ return
+ }
val filters = Filters.newBuilder().setRefuseSeconds(5).build()
for (offer <- offers.asScala) {
val offerAttributes = toAttributeMap(offer.getAttributesList)
@@ -364,7 +387,29 @@ private[spark] class CoarseMesosSchedulerBackend(
}
override def stop() {
- super.stop()
+ // Make sure we're not launching tasks during shutdown
+ stateLock.synchronized {
+ if (stopCalled) {
+ logWarning("Stop called multiple times, ignoring")
+ return
+ }
+ stopCalled = true
+ super.stop()
+ }
+ // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them.
+ // See SPARK-12330
+ val stopwatch = new Stopwatch()
+ stopwatch.start()
+ // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent
+ while (slaveIdsWithExecutors.nonEmpty &&
+ stopwatch.elapsed(TimeUnit.MILLISECONDS) < shutdownTimeoutMS) {
+ Thread.sleep(100)
+ }
+ if (slaveIdsWithExecutors.nonEmpty) {
+ logWarning(s"Timed out waiting for ${slaveIdsWithExecutors.size} remaining executors "
+ + s"to terminate within $shutdownTimeoutMS ms. This may leave temporary files "
+ + "on the mesos nodes.")
+ }
if (mesosDriver != null) {
mesosDriver.stop()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala
index e0c547dce6d07..092d9e4182530 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala
@@ -53,9 +53,9 @@ private[spark] trait MesosClusterPersistenceEngine {
* all of them reuses the same connection pool.
*/
private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf)
- extends MesosClusterPersistenceEngineFactory(conf) {
+ extends MesosClusterPersistenceEngineFactory(conf) with Logging {
- lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url")
+ lazy val zk = SparkCuratorUtil.newClient(conf)
def createEngine(path: String): MesosClusterPersistenceEngine = {
new ZookeeperMesosClusterPersistenceEngine(path, zk, conf)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index 05fda0fded7f8..8cda4ff0eb3b3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -394,7 +394,7 @@ private[spark] class MesosClusterScheduler(
.getOrElse {
throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
}
- val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath
+ val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getPath
// Sandbox points to the current directory by default with Mesos.
(cmdExecutable, ".")
}
@@ -573,6 +573,7 @@ private[spark] class MesosClusterScheduler(
override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {}
override def error(driver: SchedulerDriver, error: String): Unit = {
logError("Error received: " + error)
+ markErr()
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index eaf0cb06d6c73..340f29bac9218 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -125,7 +125,7 @@ private[spark] class MesosSchedulerBackend(
val executorBackendName = classOf[MesosExecutorBackend].getName
if (uri.isEmpty) {
- val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath
+ val executorPath = new File(executorSparkHome, "/bin/spark-class").getPath
command.setValue(s"$prefixEnv $executorPath $executorBackendName")
} else {
// Grab everything to the first '.'. We'll use that and '*' to
@@ -375,6 +375,7 @@ private[spark] class MesosSchedulerBackend(
override def error(d: SchedulerDriver, message: String) {
inClassLoader() {
logError("Mesos error: " + message)
+ markErr()
scheduler.error(message)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index 010caff3e39b2..f9f5da9bc8df6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -106,28 +106,37 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
registerLatch.await()
return
}
+ @volatile
+ var error: Option[Exception] = None
+ // We create a new thread that will block inside `mesosDriver.run`
+    // until the scheduler exits
new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") {
setDaemon(true)
-
override def run() {
- mesosDriver = newDriver
try {
+ mesosDriver = newDriver
val ret = mesosDriver.run()
logInfo("driver.run() returned with code " + ret)
if (ret != null && ret.equals(Status.DRIVER_ABORTED)) {
- System.exit(1)
+ error = Some(new SparkException("Error starting driver, DRIVER_ABORTED"))
+ markErr()
}
} catch {
case e: Exception => {
logError("driver.run() failed", e)
- System.exit(1)
+ error = Some(e)
+ markErr()
}
}
}
}.start()
registerLatch.await()
+
+ // propagate any error to the calling thread. This ensures that SparkContext creation fails
+ // without leaving a broken context that won't be able to schedule any tasks
+ error.foreach(throw _)
}
}
@@ -144,6 +153,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
registerLatch.countDown()
}
+ protected def markErr(): Unit = {
+ registerLatch.countDown()
+ }
+
def createResource(name: String, amount: Double, role: Option[String] = None): Resource = {
val builder = Resource.newBuilder()
.setName(name)
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala
index 0fc0fb59d861f..0f30183682469 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala
@@ -71,6 +71,13 @@ private[spark] object ApplicationsListResource {
attemptId = internalAttemptInfo.attemptId,
startTime = new Date(internalAttemptInfo.startTime),
endTime = new Date(internalAttemptInfo.endTime),
+ duration =
+ if (internalAttemptInfo.endTime > 0) {
+ internalAttemptInfo.endTime - internalAttemptInfo.startTime
+ } else {
+ 0
+ },
+ lastUpdated = new Date(internalAttemptInfo.lastUpdated),
sparkUser = internalAttemptInfo.sparkUser,
completed = internalAttemptInfo.completed
)
@@ -93,6 +100,13 @@ private[spark] object ApplicationsListResource {
attemptId = None,
startTime = new Date(internal.startTime),
endTime = new Date(internal.endTime),
+ duration =
+ if (internal.endTime > 0) {
+ internal.endTime - internal.startTime
+ } else {
+ 0
+ },
+ lastUpdated = new Date(internal.endTime),
sparkUser = internal.desc.user,
completed = completed
))
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index 3adf5b1109af4..d116e68c17f18 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -35,6 +35,8 @@ class ApplicationAttemptInfo private[spark](
val attemptId: Option[String],
val startTime: Date,
val endTime: Date,
+ val lastUpdated: Date,
+ val duration: Long,
val sparkUser: String,
val completed: Boolean = false)
@@ -55,6 +57,7 @@ class ExecutorSummary private[spark](
val rddBlocks: Int,
val memoryUsed: Long,
val diskUsed: Long,
+ val totalCores: Int,
val maxTasks: Int,
val activeTasks: Int,
val failedTasks: Int,
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 76aaa782b9524..024b660ce6a7b 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -255,8 +255,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
var memoryThreshold = initialMemoryThreshold
// Memory to request as a multiple of current vector size
val memoryGrowthFactor = 1.5
- // Previous unroll memory held by this task, for releasing later (only at the very end)
- val previousMemoryReserved = currentUnrollMemoryForThisTask
+ // Keep track of pending unroll memory reserved by this method.
+ var pendingMemoryReserved = 0L
// Underlying vector for unrolling the block
var vector = new SizeTrackingVector[Any]
@@ -266,6 +266,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
if (!keepUnrolling) {
logWarning(s"Failed to reserve initial memory threshold of " +
s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+ } else {
+ pendingMemoryReserved += initialMemoryThreshold
}
// Unroll this block safely, checking whether we have exceeded our threshold periodically
@@ -278,6 +280,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
if (currentSize >= memoryThreshold) {
val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
+ if (keepUnrolling) {
+ pendingMemoryReserved += amountToRequest
+ }
// New threshold is currentSize * memoryGrowthFactor
memoryThreshold += amountToRequest
}
@@ -304,10 +309,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
// release the unroll memory yet. Instead, we transfer it to pending unroll memory
// so `tryToPut` can further transfer it to normal storage memory later.
// TODO: we can probably express this without pending unroll memory (SPARK-10907)
- val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved
- unrollMemoryMap(taskAttemptId) -= amountToTransferToPending
+ unrollMemoryMap(taskAttemptId) -= pendingMemoryReserved
pendingUnrollMemoryMap(taskAttemptId) =
- pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending
+ pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + pendingMemoryReserved
}
} else {
// Otherwise, if we return an iterator, we can only release the unroll memory when
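
The MemoryStore change above swaps a "difference of totals" calculation for a running counter that only grows when a reservation actually succeeds. A minimal sketch of that accounting pattern, with invented names (`UnrollAccounting`, `reserve`) rather than the real BlockManager/MemoryManager API:

```scala
object UnrollAccountingExample extends App {
  // Invented stand-in for the unroll-memory bookkeeping; reserve() fakes the
  // memory manager by granting requests only while `available` lasts.
  class UnrollAccounting(private var available: Long) {
    private var pending = 0L

    private def reserve(amount: Long): Boolean =
      if (amount <= available) { available -= amount; true } else false

    // Mirror of the pattern above: bump the running counter only when the
    // reservation actually succeeded, instead of diffing totals at the end.
    def tryReserve(amount: Long): Boolean = {
      val granted = reserve(amount)
      if (granted) pending += amount
      granted
    }

    def pendingReserved: Long = pending
  }

  val acct = new UnrollAccounting(available = 100L)
  acct.tryReserve(64L)           // granted and counted
  acct.tryReserve(64L)           // denied, so not counted
  println(acct.pendingReserved)  // 64
}
```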
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index cf45414c4f786..6cc30eeaf5d82 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -114,6 +114,8 @@ private[spark] class SparkUI private (
attemptId = None,
startTime = new Date(startTime),
endTime = new Date(-1),
+ duration = 0,
+ lastUpdated = new Date(startTime),
sparkUser = "",
completed = false
))
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 1949c4b3cbf42..4ebee9093d41c 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -157,11 +157,22 @@ private[spark] object UIUtils extends Logging {
def commonHeaderNodes: Seq[Node] = {
+ [added lines lost in extraction: <link> and <script> includes for the new DataTables, jQuery plugin, and jsonFormatter static assets listed in .rat-excludes and LICENSE]
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
index e36b96b3e6978..e1f754999912b 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -75,6 +75,7 @@ private[ui] class ExecutorsPage(
RDD Blocks
Storage Memory
Disk Used
+
Cores
Active Tasks
Failed Tasks
Complete Tasks
@@ -131,6 +132,7 @@ private[ui] class ExecutorsPage(
@@ -174,6 +176,7 @@ private[ui] class ExecutorsPage(
val maximumMemory = execInfo.map(_.maxMemory).sum
val memoryUsed = execInfo.map(_.memoryUsed).sum
val diskUsed = execInfo.map(_.diskUsed).sum
+ val totalCores = execInfo.map(_.totalCores).sum
val totalInputBytes = execInfo.map(_.totalInputBytes).sum
val totalShuffleRead = execInfo.map(_.totalShuffleRead).sum
val totalShuffleWrite = execInfo.map(_.totalShuffleWrite).sum
@@ -188,6 +191,7 @@ private[ui] class ExecutorsPage(
{Utils.bytesToString(diskUsed)}
+
{totalCores}
{taskData(execInfo.map(_.maxTasks).sum,
execInfo.map(_.activeTasks).sum,
execInfo.map(_.failedTasks).sum,
@@ -211,6 +215,7 @@ private[ui] class ExecutorsPage(
RDD Blocks
Storage Memory
Disk Used
+
Cores
Active Tasks
Failed Tasks
Complete Tasks
@@ -305,6 +310,7 @@ private[spark] object ExecutorsPage {
val memUsed = status.memUsed
val maxMem = status.maxMem
val diskUsed = status.diskUsed
+ val totalCores = listener.executorToTotalCores.getOrElse(execId, 0)
val maxTasks = listener.executorToTasksMax.getOrElse(execId, 0)
val activeTasks = listener.executorToTasksActive.getOrElse(execId, 0)
val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
@@ -323,6 +329,7 @@ private[spark] object ExecutorsPage {
rddBlocks,
memUsed,
diskUsed,
+ totalCores,
maxTasks,
activeTasks,
failedTasks,
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index a9e926b158780..dcfebe92ed805 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -45,6 +45,7 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "exec
@DeveloperApi
class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf)
extends SparkListener {
+ val executorToTotalCores = HashMap[String, Int]()
val executorToTasksMax = HashMap[String, Int]()
val executorToTasksActive = HashMap[String, Int]()
val executorToTasksComplete = HashMap[String, Int]()
@@ -65,8 +66,8 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar
override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized {
val eid = executorAdded.executorId
executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap
- executorToTasksMax(eid) =
- executorAdded.executorInfo.totalCores / conf.getInt("spark.task.cpus", 1)
+ executorToTotalCores(eid) = executorAdded.executorInfo.totalCores
+ executorToTasksMax(eid) = executorToTotalCores(eid) / conf.getInt("spark.task.cpus", 1)
executorIdToData(eid) = ExecutorUIData(executorAdded.time)
}
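
A small sketch, outside the diff, of the bookkeeping ExecutorsTab now performs: record the executor's total cores and derive its concurrent task slots from `spark.task.cpus`. The executor id and core count are made up; `SparkConf.getInt` is the call used in the hunk above.

```scala
import scala.collection.mutable.HashMap

import org.apache.spark.SparkConf

object ExecutorCoresExample extends App {
  // Illustrative maps mirroring the listener's bookkeeping.
  val executorToTotalCores = HashMap[String, Int]()
  val executorToTasksMax = HashMap[String, Int]()

  val conf = new SparkConf(false).set("spark.task.cpus", "2")
  val eid = "exec-1"
  val totalCores = 8

  executorToTotalCores(eid) = totalCores
  executorToTasksMax(eid) = executorToTotalCores(eid) / conf.getInt("spark.task.cpus", 1)

  println(executorToTasksMax(eid))  // 4 concurrent task slots on this executor
}
```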
diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala
index d484cec7ae384..1bf6f821e9b31 100644
--- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala
+++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
import org.apache.commons.lang3.SystemUtils
@@ -59,17 +60,21 @@ private[spark] class Benchmark(
}
println
- val firstRate = results.head.avgRate
+ val firstBest = results.head.bestMs
+ val firstAvg = results.head.avgMs
// The results are going to be processor specific so it is useful to include that.
println(Benchmark.getProcessorName())
- printf("%-30s %16s %16s %14s\n", name + ":", "Avg Time(ms)", "Avg Rate(M/s)", "Relative Rate")
- println("-------------------------------------------------------------------------------")
- results.zip(benchmarks).foreach { r =>
- printf("%-30s %16s %16s %14s\n",
- r._2.name,
- "%10.2f" format r._1.avgMs,
- "%10.2f" format r._1.avgRate,
- "%6.2f X" format (r._1.avgRate / firstRate))
+ printf("%-35s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)",
+ "Per Row(ns)", "Relative")
+ println("-----------------------------------------------------------------------------------" +
+ "--------")
+ results.zip(benchmarks).foreach { case (result, benchmark) =>
+ printf("%-35s %16s %12s %13s %10s\n",
+ benchmark.name,
+ "%5.0f / %4.0f" format (result.bestMs, result.avgMs),
+ "%10.1f" format result.bestRate,
+ "%6.1f" format (1000 / result.bestRate),
+ "%3.1fX" format (firstBest / result.bestMs))
}
println
// scalastyle:on
@@ -78,7 +83,7 @@ private[spark] class Benchmark(
private[spark] object Benchmark {
case class Case(name: String, fn: Int => Unit)
- case class Result(avgMs: Double, avgRate: Double)
+ case class Result(avgMs: Double, bestRate: Double, bestMs: Double)
/**
* This should return a user helpful processor information. Getting at this depends on the OS.
@@ -99,22 +104,27 @@ private[spark] object Benchmark {
* the rate of the function.
*/
def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = {
- var totalTime = 0L
+ val runTimes = ArrayBuffer[Long]()
for (i <- 0 until iters + 1) {
val start = System.nanoTime()
f(i)
val end = System.nanoTime()
- if (i != 0) totalTime += end - start
+ val runTime = end - start
+ if (i > 0) {
+ runTimes += runTime
+ }
if (outputPerIteration) {
// scalastyle:off
- println(s"Iteration $i took ${(end - start) / 1000} microseconds")
+ println(s"Iteration $i took ${runTime / 1000} microseconds")
// scalastyle:on
}
}
- Result(totalTime.toDouble / 1000000 / iters, num * iters / (totalTime.toDouble / 1000))
+ val best = runTimes.min
+ val avg = runTimes.sum / iters
+ Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0)
}
}
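
A self-contained sketch of the measurement scheme the Benchmark change introduces: run `iters + 1` times, drop the warm-up iteration, and report both best and average times plus a rate derived from the best run. The workload and iteration count below are illustrative only.

```scala
import scala.collection.mutable.ArrayBuffer

object BestAvgExample extends App {
  // Mirror of the new Result shape: average time, best rate, best time.
  case class Result(avgMs: Double, bestRate: Double, bestMs: Double)

  // Run the workload iters + 1 times, discard the warm-up run (i == 0), and
  // keep every remaining wall-clock time so both best and average can be reported.
  def measure(num: Long, iters: Int)(f: Int => Unit): Result = {
    val runTimes = ArrayBuffer[Long]()
    for (i <- 0 until iters + 1) {
      val start = System.nanoTime()
      f(i)
      val end = System.nanoTime()
      if (i > 0) runTimes += end - start
    }
    val best = runTimes.min
    val avg = runTimes.sum / iters
    // times in ms; rate in millions of "rows" per second, derived from the best run
    Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0)
  }

  val n = 1000000L
  val result = measure(n, iters = 5) { _ => (0L until n).foreach(identity) }
  printf("best %.1f ms / avg %.1f ms, %.1f M/s%n", result.bestMs, result.avgMs, result.bestRate)
}
```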
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 1829c34eb1cb1..09d955300a64a 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -290,7 +290,8 @@ private[spark] object JsonProtocol {
("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~
("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~
("Internal" -> accumulableInfo.internal) ~
- ("Count Failed Values" -> accumulableInfo.countFailedValues)
+ ("Count Failed Values" -> accumulableInfo.countFailedValues) ~
+ ("Metadata" -> accumulableInfo.metadata)
}
/**
@@ -728,7 +729,8 @@ private[spark] object JsonProtocol {
val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) }
val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false)
val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false)
- new AccumulableInfo(id, name, update, value, internal, countFailedValues)
+ val metadata = (json \ "Metadata").extractOpt[String]
+ new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata)
}
/**
@@ -809,8 +811,8 @@ private[spark] object JsonProtocol {
Utils.jsonOption(json \ "Input Metrics").foreach { inJson =>
val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String])
val inputMetrics = metrics.registerInputMetrics(readMethod)
- inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long])
- inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L))
+ inputMetrics.incBytesReadInternal((inJson \ "Bytes Read").extract[Long])
+ inputMetrics.incRecordsReadInternal((inJson \ "Records Read").extractOpt[Long].getOrElse(0L))
}
// Updated blocks
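
The JsonProtocol change threads an optional metadata string through serialization and back. A hedged sketch of the same round-trip with json4s (the library JsonProtocol uses), built around an invented `AccumInfo` stand-in rather than the real AccumulableInfo:

```scala
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object MetadataRoundTripExample extends App {
  implicit val formats: Formats = DefaultFormats

  // Invented stand-in for AccumulableInfo, trimmed to the fields of interest.
  case class AccumInfo(id: Long, countFailedValues: Boolean, metadata: Option[String])

  def toJson(a: AccumInfo): JValue =
    ("ID" -> a.id) ~
    ("Count Failed Values" -> a.countFailedValues) ~
    ("Metadata" -> a.metadata)  // None becomes JNothing and is dropped from the object

  def fromJson(json: JValue): AccumInfo = {
    val id = (json \ "ID").extract[Long]
    val countFailed = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false)
    val metadata = (json \ "Metadata").extractOpt[String]  // tolerant of logs without the field
    AccumInfo(id, countFailed, metadata)
  }

  println(fromJson(toJson(AccumInfo(1L, countFailedValues = true, metadata = Some("sql")))))
  println(fromJson(parse("""{"ID": 2, "Count Failed Values": false}""")))  // metadata -> None
}
```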
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 0328e63e45439..eb1da8e1b43eb 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -75,6 +75,9 @@ public void testBasicSorting() throws Exception {
// Write the records into the data page and store pointers into the sorter
long position = dataPage.getBaseOffset();
for (String str : dataToSort) {
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2));
+ }
final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
final byte[] strBytes = str.getBytes("utf-8");
Platform.putInt(baseObject, position, strBytes.length);
@@ -114,6 +117,9 @@ public void testSortingManyNumbers() throws Exception {
int[] numbersToSort = new int[128000];
Random random = new Random(16);
for (int i = 0; i < numbersToSort.length; i++) {
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2));
+ }
numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
sorter.insertRecord(0, numbersToSort[i]);
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 93efd033eb940..8e557ec0ab0b4 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -111,6 +111,9 @@ public int compare(long prefix1, long prefix2) {
// Given a page of records, insert those records into the sorter one-by-one:
position = dataPage.getBaseOffset();
for (int i = 0; i < dataToSort.length; i++) {
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2 * 2));
+ }
// position now points to the start of a record (which holds its length).
final int recordLength = Platform.getInt(baseObject, position);
final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
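
Both sorter suites now grow the pointer array before each insert rather than assuming capacity. A simplified Scala sketch of that grow-before-insert pattern, using an invented `GrowableLongArray` instead of the real ShuffleInMemorySorter/UnsafeInMemorySorter API:

```scala
object GrowBeforeInsertExample extends App {
  // Invented, simplified stand-in for the sorters' pointer array.
  class GrowableLongArray(initialCapacity: Int) {
    private var data = new Array[Long](initialCapacity)
    private var count = 0

    def hasSpaceForAnotherRecord: Boolean = count < data.length
    def numRecords: Int = count

    // Copy the existing pointers into a larger array, as expandPointerArray does.
    def expand(newCapacity: Int): Unit = {
      data = java.util.Arrays.copyOf(data, newCapacity)
    }

    def insert(pointer: Long): Unit = {
      require(hasSpaceForAnotherRecord, "caller must expand first")
      data(count) = pointer
      count += 1
    }
  }

  val pointers = new GrowableLongArray(initialCapacity = 2)
  (1L to 10L).foreach { p =>
    if (!pointers.hasSpaceForAnotherRecord) {
      pointers.expand(pointers.numRecords * 2)  // double the capacity, as in the tests above
    }
    pointers.insert(p)
  }
  println(pointers.numRecords)  // 10
}
```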
diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json
index d575bf2f284b9..5bbb4ceb97228 100644
--- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json
@@ -4,6 +4,8 @@
"attempts" : [ {
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:11.398GMT",
+ "lastUpdated" : "",
+ "duration" : 10505,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -14,12 +16,16 @@
"attemptId" : "2",
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:00.950GMT",
+ "lastUpdated" : "",
+ "duration" : 57,
"sparkUser" : "irashid",
"completed" : true
}, {
"attemptId" : "1",
"startTime" : "2015-05-06T13:03:00.880GMT",
"endTime" : "2015-05-06T13:03:00.890GMT",
+ "lastUpdated" : "",
+ "duration" : 10,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -30,12 +36,16 @@
"attemptId" : "2",
"startTime" : "2015-03-17T23:11:50.242GMT",
"endTime" : "2015-03-17T23:12:25.177GMT",
+ "lastUpdated" : "",
+ "duration" : 34935,
"sparkUser" : "irashid",
"completed" : true
}, {
"attemptId" : "1",
"startTime" : "2015-03-16T19:25:10.242GMT",
"endTime" : "2015-03-16T19:25:45.177GMT",
+ "lastUpdated" : "",
+ "duration" : 34935,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -45,6 +55,8 @@
"attempts" : [ {
"startTime" : "2015-02-28T00:02:38.277GMT",
"endTime" : "2015-02-28T00:02:46.912GMT",
+ "lastUpdated" : "",
+ "duration" : 8635,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -54,6 +66,8 @@
"attempts" : [ {
"startTime" : "2015-02-03T16:42:59.720GMT",
"endTime" : "2015-02-03T16:43:08.731GMT",
+ "lastUpdated" : "",
+ "duration" : 9011,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -63,7 +77,9 @@
"attempts" : [ {
"startTime" : "2015-02-03T16:42:38.277GMT",
"endTime" : "2015-02-03T16:42:46.912GMT",
+ "lastUpdated" : "",
+ "duration" : 8635,
"sparkUser" : "irashid",
"completed" : true
} ]
-} ]
\ No newline at end of file
+} ]
diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json
index d575bf2f284b9..5bbb4ceb97228 100644
--- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json
@@ -4,6 +4,8 @@
"attempts" : [ {
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:11.398GMT",
+ "lastUpdated" : "",
+ "duration" : 10505,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -14,12 +16,16 @@
"attemptId" : "2",
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:00.950GMT",
+ "lastUpdated" : "",
+ "duration" : 57,
"sparkUser" : "irashid",
"completed" : true
}, {
"attemptId" : "1",
"startTime" : "2015-05-06T13:03:00.880GMT",
"endTime" : "2015-05-06T13:03:00.890GMT",
+ "lastUpdated" : "",
+ "duration" : 10,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -30,12 +36,16 @@
"attemptId" : "2",
"startTime" : "2015-03-17T23:11:50.242GMT",
"endTime" : "2015-03-17T23:12:25.177GMT",
+ "lastUpdated" : "",
+ "duration" : 34935,
"sparkUser" : "irashid",
"completed" : true
}, {
"attemptId" : "1",
"startTime" : "2015-03-16T19:25:10.242GMT",
"endTime" : "2015-03-16T19:25:45.177GMT",
+ "lastUpdated" : "",
+ "duration" : 34935,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -45,6 +55,8 @@
"attempts" : [ {
"startTime" : "2015-02-28T00:02:38.277GMT",
"endTime" : "2015-02-28T00:02:46.912GMT",
+ "lastUpdated" : "",
+ "duration" : 8635,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -54,6 +66,8 @@
"attempts" : [ {
"startTime" : "2015-02-03T16:42:59.720GMT",
"endTime" : "2015-02-03T16:43:08.731GMT",
+ "lastUpdated" : "",
+ "duration" : 9011,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -63,7 +77,9 @@
"attempts" : [ {
"startTime" : "2015-02-03T16:42:38.277GMT",
"endTime" : "2015-02-03T16:42:46.912GMT",
+ "lastUpdated" : "",
+ "duration" : 8635,
"sparkUser" : "irashid",
"completed" : true
} ]
-} ]
\ No newline at end of file
+} ]
diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
index 94f8aeac55b5d..9d5d224e55176 100644
--- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
@@ -4,6 +4,7 @@
"rddBlocks" : 8,
"memoryUsed" : 28000128,
"diskUsed" : 0,
+ "totalCores" : 0,
"maxTasks" : 0,
"activeTasks" : 0,
"failedTasks" : 1,
diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json
index 483632a3956ed..3f80a529a08b9 100644
--- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json
@@ -4,7 +4,9 @@
"attempts" : [ {
"startTime" : "2015-02-03T16:42:38.277GMT",
"endTime" : "2015-02-03T16:42:46.912GMT",
+ "lastUpdated" : "",
+ "duration" : 8635,
"sparkUser" : "irashid",
"completed" : true
} ]
-} ]
\ No newline at end of file
+} ]
diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json
index 4b85690fd9199..508bdc17efe9f 100644
--- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json
@@ -4,6 +4,8 @@
"attempts" : [ {
"startTime" : "2015-02-03T16:42:59.720GMT",
"endTime" : "2015-02-03T16:43:08.731GMT",
+ "lastUpdated" : "",
+ "duration" : 9011,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -13,7 +15,9 @@
"attempts" : [ {
"startTime" : "2015-02-03T16:42:38.277GMT",
"endTime" : "2015-02-03T16:42:46.912GMT",
+ "lastUpdated" : "",
+ "duration" : 8635,
"sparkUser" : "irashid",
"completed" : true
} ]
-} ]
\ No newline at end of file
+} ]
diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json
index 15c2de8ef99ea..5dca7d73de0cc 100644
--- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json
@@ -4,6 +4,8 @@
"attempts" : [ {
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:11.398GMT",
+ "lastUpdated" : "",
+ "duration" : 10505,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -14,12 +16,16 @@
"attemptId" : "2",
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:00.950GMT",
+ "lastUpdated" : "",
+ "duration" : 57,
"sparkUser" : "irashid",
"completed" : true
}, {
"attemptId" : "1",
"startTime" : "2015-05-06T13:03:00.880GMT",
"endTime" : "2015-05-06T13:03:00.890GMT",
+ "lastUpdated" : "",
+ "duration" : 10,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -30,12 +36,16 @@
"attemptId" : "2",
"startTime" : "2015-03-17T23:11:50.242GMT",
"endTime" : "2015-03-17T23:12:25.177GMT",
+ "lastUpdated" : "",
+ "duration" : 34935,
"sparkUser" : "irashid",
"completed" : true
}, {
"attemptId" : "1",
"startTime" : "2015-03-16T19:25:10.242GMT",
"endTime" : "2015-03-16T19:25:45.177GMT",
+ "lastUpdated" : "",
+ "duration" : 34935,
"sparkUser" : "irashid",
"completed" : true
} ]
@@ -46,8 +56,10 @@
{
"startTime": "2015-02-28T00:02:38.277GMT",
"endTime": "2015-02-28T00:02:46.912GMT",
+ "lastUpdated" : "",
+ "duration" : 8635,
"sparkUser": "irashid",
"completed": true
}
]
-} ]
\ No newline at end of file
+} ]
diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json
index 07489ad96414a..cca32c791074f 100644
--- a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json
@@ -4,7 +4,9 @@
"attempts" : [ {
"startTime" : "2015-02-03T16:42:59.720GMT",
"endTime" : "2015-02-03T16:43:08.731GMT",
+ "lastUpdated" : "",
+ "duration" : 9011,
"sparkUser" : "irashid",
"completed" : true
} ]
-}
\ No newline at end of file
+}
diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json
index 8f3d7160c723f..1ea1779e8369d 100644
--- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json
@@ -5,13 +5,17 @@
"attemptId" : "2",
"startTime" : "2015-03-17T23:11:50.242GMT",
"endTime" : "2015-03-17T23:12:25.177GMT",
+ "lastUpdated" : "",
+ "duration" : 34935,
"sparkUser" : "irashid",
"completed" : true
}, {
"attemptId" : "1",
"startTime" : "2015-03-16T19:25:10.242GMT",
"endTime" : "2015-03-16T19:25:45.177GMT",
+ "lastUpdated" : "",
+ "duration" : 34935,
"sparkUser" : "irashid",
"completed" : true
} ]
-}
\ No newline at end of file
+}
diff --git a/core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981759269
similarity index 100%
rename from core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1
rename to core/src/test/resources/spark-events/local-1422981759269
diff --git a/core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981780767
similarity index 100%
rename from core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1
rename to core/src/test/resources/spark-events/local-1422981780767
diff --git a/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1425081759269
similarity index 100%
rename from core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1
rename to core/src/test/resources/spark-events/local-1425081759269
diff --git a/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426533911241
similarity index 100%
rename from core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1
rename to core/src/test/resources/spark-events/local-1426533911241
diff --git a/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426633911242
similarity index 100%
rename from core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1
rename to core/src/test/resources/spark-events/local-1426633911242
diff --git a/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0
deleted file mode 100755
index e69de29bb2d1d..0000000000000
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 193c0a2479da6..4d49fe5159850 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -307,6 +307,8 @@ private[spark] object AccumulatorSuite {
val listener = new SaveInfoListener
sc.addSparkListener(listener)
testBody
+ // wait until all events have been processed before proceeding to assert things
+ sc.listenerBus.waitUntilEmpty(10 * 1000)
val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
val isSet = accums.exists { a =>
a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L)
@@ -321,35 +323,60 @@ private[spark] object AccumulatorSuite {
* A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs.
*/
private class SaveInfoListener extends SparkListener {
- private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo]
- private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo]
- private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID
+ type StageId = Int
+ type StageAttemptId = Int
- // Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated
+ private val completedStageInfos = new ArrayBuffer[StageInfo]
+ private val completedTaskInfos =
+ new mutable.HashMap[(StageId, StageAttemptId), ArrayBuffer[TaskInfo]]
+
+ // Callback to call when a job completes; it takes no arguments.
@GuardedBy("this")
+ private var jobCompletionCallback: () => Unit = null
+ private var calledJobCompletionCallback: Boolean = false
private var exception: Throwable = null
def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
- def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq
+ def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq
+ def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] =
+ completedTaskInfos.get((stageId, stageAttemptId)).getOrElse(Seq.empty[TaskInfo])
- /** Register a callback to be called on job end. */
- def registerJobCompletionCallback(callback: (Int => Unit)): Unit = {
- jobCompletionCallback = callback
+ /**
+ * If `jobCompletionCallback` is set, block until the next call has finished.
+ * If the callback failed with an exception, throw it.
+ */
+ def awaitNextJobCompletion(): Unit = synchronized {
+ if (jobCompletionCallback != null) {
+ while (!calledJobCompletionCallback) {
+ wait()
+ }
+ calledJobCompletionCallback = false
+ if (exception != null) {
+ val toThrow = exception
+ exception = null
+ throw toThrow
+ }
+ }
}
- /** Throw a stored exception, if any. */
- def maybeThrowException(): Unit = synchronized {
- if (exception != null) { throw exception }
+ /**
+ * Register a callback to be called on job end.
+ * A call to this should be followed by [[awaitNextJobCompletion]].
+ */
+ def registerJobCompletionCallback(callback: () => Unit): Unit = {
+ jobCompletionCallback = callback
}
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized {
if (jobCompletionCallback != null) {
try {
- jobCompletionCallback(jobEnd.jobId)
+ jobCompletionCallback()
} catch {
// Store any exception thrown here so we can throw them later in the main thread.
// Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test.
case NonFatal(e) => exception = e
+ } finally {
+ calledJobCompletionCallback = true
+ notify()
}
}
}
@@ -359,7 +386,8 @@ private class SaveInfoListener extends SparkListener {
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
- completedTaskInfos += taskEnd.taskInfo
+ completedTaskInfos.getOrElseUpdate(
+ (taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo
}
}
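
The listener above hands assertion failures from the listener-bus thread back to the test thread via wait/notify. A standalone sketch of that hand-off, with the stored exception captured before it is cleared so the caller never throws null; `CallbackGate` and its methods are invented names:

```scala
object CallbackGateExample extends App {
  // Only sketches the wait/notify pattern used by the listener.
  class CallbackGate {
    private var done = false
    private var failure: Throwable = null

    // Called on the listener thread once the callback has run.
    def complete(error: Option[Throwable]): Unit = synchronized {
      failure = error.orNull
      done = true
      notify()
    }

    // Called on the test thread; blocks, then rethrows any stored failure.
    def await(): Unit = synchronized {
      while (!done) wait()
      done = false
      if (failure != null) {
        val toThrow = failure  // capture before clearing, so we never throw null
        failure = null
        throw toThrow
      }
    }
  }

  val gate = new CallbackGate
  new Thread(new Runnable {
    override def run(): Unit = gate.complete(Some(new IllegalStateException("assert failed")))
  }).start()

  try gate.await()
  catch { case e: IllegalStateException => println(s"propagated to caller: ${e.getMessage}") }
}
```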
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 390764ba242fd..ce35856dce3f7 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -512,6 +512,27 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
assert(rdd.isCheckpointedAndMaterialized === true)
assert(rdd.partitions.size === 0)
}
+
+ runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean =>
+ testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true)
+ testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false)
+ }
+
+ private def testCheckpointAllMarkedAncestors(
+ reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = {
+ sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString)
+ try {
+ val rdd1 = sc.parallelize(1 to 10)
+ checkpoint(rdd1, reliableCheckpoint)
+ val rdd2 = rdd1.map(_ + 1)
+ checkpoint(rdd2, reliableCheckpoint)
+ rdd2.count()
+ assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors)
+ assert(rdd2.isCheckpointed === true)
+ } finally {
+ sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null)
+ }
+ }
}
/** RDD partition that has large serialized size. */
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index b69f70cab3d3f..c426bb7a4e809 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockStatus}
@@ -160,7 +161,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
iter
}
// Register asserts in job completion callback to avoid flakiness
- listener.registerJobCompletionCallback { _ =>
+ listener.registerJobCompletionCallback { () =>
val stageInfos = listener.getCompletedStageInfos
val taskInfos = listener.getCompletedTaskInfos
assert(stageInfos.size === 1)
@@ -179,6 +180,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
}
rdd.count()
+ listener.awaitNextJobCompletion()
}
test("internal accumulators in multiple stages") {
@@ -205,7 +207,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
iter
}
// Register asserts in job completion callback to avoid flakiness
- listener.registerJobCompletionCallback { _ =>
+ listener.registerJobCompletionCallback { () =>
// We ran 3 stages, and the accumulator values should be distinct
val stageInfos = listener.getCompletedStageInfos
assert(stageInfos.size === 3)
@@ -220,13 +222,66 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
rdd.count()
}
- // TODO: these two tests are incorrect; they don't actually trigger stage retries (SPARK-13053).
- ignore("internal accumulators in fully resubmitted stages") {
- testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
- }
+ test("internal accumulators in resubmitted stages") {
+ val listener = new SaveInfoListener
+ val numPartitions = 10
+ sc = new SparkContext("local", "test")
+ sc.addSparkListener(listener)
+
+ // Simulate fetch failures in order to trigger a stage retry. Here we run 1 job with
+ // 2 stages. On the second stage, we trigger a fetch failure on the first stage attempt.
+ // This should retry both stages in the scheduler. Note that we only want to fail the
+ // first stage attempt because we want the stage to eventually succeed.
+ val x = sc.parallelize(1 to 100, numPartitions)
+ .mapPartitions { iter => TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1; iter }
+ .groupBy(identity)
+ val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId
+ val rdd = x.mapPartitionsWithIndex { case (i, iter) =>
+ // Fail the first stage attempt. Here we use the task attempt ID to determine this.
+ // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt
+ // ID that's < 2 * numPartitions belongs to the first attempt of this stage.
+ val taskContext = TaskContext.get()
+ val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2
+ if (isFirstStageAttempt) {
+ throw new FetchFailedException(
+ SparkEnv.get.blockManager.blockManagerId,
+ sid,
+ taskContext.partitionId(),
+ taskContext.partitionId(),
+ "simulated fetch failure")
+ } else {
+ iter
+ }
+ }
- ignore("internal accumulators in partially resubmitted stages") {
- testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+ // Register asserts in job completion callback to avoid flakiness
+ listener.registerJobCompletionCallback { () =>
+ val stageInfos = listener.getCompletedStageInfos
+ assert(stageInfos.size === 4) // 1 shuffle map stage + 1 result stage, both are retried
+ val mapStageId = stageInfos.head.stageId
+ val mapStageInfo1stAttempt = stageInfos.head
+ val mapStageInfo2ndAttempt = {
+ stageInfos.tail.find(_.stageId == mapStageId).getOrElse {
+ fail("expected two attempts of the same shuffle map stage.")
+ }
+ }
+ val stageAccum1stAttempt = findTestAccum(mapStageInfo1stAttempt.accumulables.values)
+ val stageAccum2ndAttempt = findTestAccum(mapStageInfo2ndAttempt.accumulables.values)
+ // Both map stages should have succeeded, since the fetch failure happened in the
+ // result stage, not the map stage. This means we should get the accumulator updates
+ // from all partitions.
+ assert(stageAccum1stAttempt.value.get.asInstanceOf[Long] === numPartitions)
+ assert(stageAccum2ndAttempt.value.get.asInstanceOf[Long] === numPartitions)
+ // Because this test resubmitted the map stage with all missing partitions, we should have
+ // created a fresh set of internal accumulators in the 2nd stage attempt. Assert this is
+ // the case by comparing the accumulator IDs between the two attempts.
+ // Note: it would be good to also test the case where the map stage is resubmitted where
+ // only a subset of the original partitions are missing. However, this scenario is very
+ // difficult to construct without potentially introducing flakiness.
+ assert(stageAccum1stAttempt.id != stageAccum2ndAttempt.id)
+ }
+ rdd.count()
+ listener.awaitNextJobCompletion()
}
test("internal accumulators are registered for cleanups") {
@@ -257,63 +312,6 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
}
}
- /**
- * Test whether internal accumulators are merged properly if some tasks fail.
- * TODO: make this actually retry the stage (SPARK-13053).
- */
- private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
- val listener = new SaveInfoListener
- val numPartitions = 10
- val numFailedPartitions = (0 until numPartitions).count(failCondition)
- // This says use 1 core and retry tasks up to 2 times
- sc = new SparkContext("local[1, 2]", "test")
- sc.addSparkListener(listener)
- val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
- val taskContext = TaskContext.get()
- taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1
- // Fail the first attempts of a subset of the tasks
- if (failCondition(i) && taskContext.attemptNumber() == 0) {
- throw new Exception("Failing a task intentionally.")
- }
- iter
- }
- // Register asserts in job completion callback to avoid flakiness
- listener.registerJobCompletionCallback { _ =>
- val stageInfos = listener.getCompletedStageInfos
- val taskInfos = listener.getCompletedTaskInfos
- assert(stageInfos.size === 1)
- assert(taskInfos.size === numPartitions + numFailedPartitions)
- val stageAccum = findTestAccum(stageInfos.head.accumulables.values)
- // If all partitions failed, then we would resubmit the whole stage again and create a
- // fresh set of internal accumulators. Otherwise, these internal accumulators do count
- // failed values, so we must include the failed values.
- val expectedAccumValue =
- if (numPartitions == numFailedPartitions) {
- numPartitions
- } else {
- numPartitions + numFailedPartitions
- }
- assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue)
- val taskAccumValues = taskInfos.flatMap { taskInfo =>
- if (!taskInfo.failed) {
- // If a task succeeded, its update value should always be 1
- val taskAccum = findTestAccum(taskInfo.accumulables)
- assert(taskAccum.update.isDefined)
- assert(taskAccum.update.get.asInstanceOf[Long] === 1L)
- assert(taskAccum.value.isDefined)
- Some(taskAccum.value.get.asInstanceOf[Long])
- } else {
- // If a task failed, we should not get its accumulator values
- assert(taskInfo.accumulables.isEmpty)
- None
- }
- }
- assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
- }
- rdd.count()
- listener.maybeThrowException()
- }
-
/**
* A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup.
*/
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index e13a442463e8d..c347ab8dc8020 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.Semaphore
import scala.concurrent.Await
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
-import scala.concurrent.future
+import scala.concurrent.Future
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers
@@ -103,7 +103,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
val rdd1 = rdd.map(x => x)
- future {
+ Future {
taskStartedSemaphore.acquire()
sc.cancelAllJobs()
taskCancelledSemaphore.release(100000)
@@ -126,7 +126,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
})
// jobA is the one to be cancelled.
- val jobA = future {
+ val jobA = Future {
sc.setJobGroup("jobA", "this is a job to be cancelled")
sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
}
@@ -191,7 +191,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
})
// jobA is the one to be cancelled.
- val jobA = future {
+ val jobA = Future {
sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100000); i }.count()
}
@@ -231,7 +231,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
val f2 = rdd.countAsync()
// Kill one of the action.
- future {
+ Future {
sem1.acquire()
f1.cancel()
JobCancellationSuite.twoJobsSharingStageSemaphore.release(10)
@@ -247,7 +247,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// Cancel before launching any tasks
{
val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
- future { f.cancel() }
+ Future { f.cancel() }
val e = intercept[SparkException] { f.get() }
assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
}
@@ -263,7 +263,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
})
val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
- future {
+ Future {
// Wait until some tasks were launched before we cancel the job.
sem.acquire()
f.cancel()
@@ -277,7 +277,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// Cancel before launching any tasks
{
val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
- future { f.cancel() }
+ Future { f.cancel() }
val e = intercept[SparkException] { f.get() }
assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
}
@@ -292,7 +292,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
}
})
val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
- future {
+ Future {
sem.acquire()
f.cancel()
}
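
The JobCancellationSuite changes replace the deprecated lowercase `future` builder with `Future`, which requires an implicit ExecutionContext in scope. A minimal usage sketch, with a made-up body standing in for the suite's "wait, then cancel" logic:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object FutureBuilderExample extends App {
  // The body is a placeholder; the suite cancels jobs here instead of printing.
  val cancelLater: Future[Unit] = Future {
    Thread.sleep(100)
    println("cancelling jobs")
  }

  Await.result(cancelLater, 10.seconds)
}
```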
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
index fdada0777f9a9..b7ff5c9e8c0d3 100644
--- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
@@ -447,7 +447,23 @@ class StandaloneDynamicAllocationSuite
apps = getApplications()
// kill executor successfully
assert(apps.head.executors.size === 1)
+ }
+ test("initial executor limit") {
+ val initialExecutorLimit = 1
+ val myConf = appConf
+ .set("spark.dynamicAllocation.enabled", "true")
+ .set("spark.shuffle.service.enabled", "true")
+ .set("spark.dynamicAllocation.initialExecutors", initialExecutorLimit.toString)
+ sc = new SparkContext(myConf)
+ val appId = sc.applicationId
+ eventually(timeout(10.seconds), interval(10.millis)) {
+ val apps = getApplications()
+ assert(apps.size === 1)
+ assert(apps.head.id === appId)
+ assert(apps.head.executors.size === initialExecutorLimit)
+ assert(apps.head.getExecutorLimit === initialExecutorLimit)
+ }
}
// ===============================
@@ -540,7 +556,6 @@ class StandaloneDynamicAllocationSuite
val missingExecutors = masterExecutors.toSet.diff(driverExecutors.toSet).toSeq.sorted
missingExecutors.foreach { id =>
// Fake an executor registration so the driver knows about us
- val port = System.currentTimeMillis % 65536
val endpointRef = mock(classOf[RpcEndpointRef])
val mockAddress = mock(classOf[RpcAddress])
when(endpointRef.address).thenReturn(mockAddress)
diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala
index eb794b6739d5e..658779360b7a5 100644
--- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.deploy.client
-import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.JavaConverters._
import scala.concurrent.duration._
import org.scalatest.BeforeAndAfterAll
@@ -165,14 +167,14 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd
/** Application Listener to collect events */
private class AppClientCollector extends AppClientListener with Logging {
- val connectedIdList = new ArrayBuffer[String] with SynchronizedBuffer[String]
+ val connectedIdList = new ConcurrentLinkedQueue[String]()
@volatile var disconnectedCount: Int = 0
- val deadReasonList = new ArrayBuffer[String] with SynchronizedBuffer[String]
- val execAddedList = new ArrayBuffer[String] with SynchronizedBuffer[String]
- val execRemovedList = new ArrayBuffer[String] with SynchronizedBuffer[String]
+ val deadReasonList = new ConcurrentLinkedQueue[String]()
+ val execAddedList = new ConcurrentLinkedQueue[String]()
+ val execRemovedList = new ConcurrentLinkedQueue[String]()
def connected(id: String): Unit = {
- connectedIdList += id
+ connectedIdList.add(id)
}
def disconnected(): Unit = {
@@ -182,7 +184,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd
}
def dead(reason: String): Unit = {
- deadReasonList += reason
+ deadReasonList.add(reason)
}
def executorAdded(
@@ -191,11 +193,11 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd
hostPort: String,
cores: Int,
memory: Int): Unit = {
- execAddedList += id
+ execAddedList.add(id)
}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = {
- execRemovedList += id
+ execRemovedList.add(id)
}
}
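
AppClientSuite (and RpcEnvSuite below) move from the deprecated `SynchronizedBuffer` mixin to `ConcurrentLinkedQueue`, reading results back through `asScala`. A small sketch of that pattern with an invented `EventCollector`:

```scala
import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.JavaConverters._

object EventCollectorExample extends App {
  // Invented listener shape; only the collection choice mirrors the change above.
  class EventCollector {
    val execAddedList = new ConcurrentLinkedQueue[String]()

    def executorAdded(id: String): Unit = execAddedList.add(id)

    def addedIds: Seq[String] = execAddedList.asScala.toSeq
  }

  val collector = new EventCollector
  val threads = (1 to 4).map { i =>
    new Thread(new Runnable {
      override def run(): Unit = collector.executorAdded(s"exec-$i")
    })
  }
  threads.foreach(_.start())
  threads.foreach(_.join())
  println(collector.addedIds.sorted)  // all four concurrent adds are retained
}
```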
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index 6cbf911395a84..3baa2e2ddad31 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -69,7 +69,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
new File(logPath)
}
- test("Parse new and old application logs") {
+ test("Parse application logs") {
val provider = new FsHistoryProvider(createTestConf())
// Write a new-style application log.
@@ -95,26 +95,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
None)
)
- // Write an old-style application log.
- val oldAppComplete = writeOldLog("old1", "1.0", None, true,
- SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None),
- SparkListenerApplicationEnd(3L)
- )
-
- // Check for logs so that we force the older unfinished app to be loaded, to make
- // sure unfinished apps are also sorted correctly.
- provider.checkForLogs()
-
- // Write an unfinished app, old-style.
- val oldAppIncomplete = writeOldLog("old2", "1.0", None, false,
- SparkListenerApplicationStart("old2", None, 2L, "test", None)
- )
-
- // Force a reload of data from the log directory, and check that both logs are loaded.
+ // Force a reload of data from the log directory, and check that logs are loaded.
// Take the opportunity to check that the offset checks work as expected.
updateAndCheck(provider) { list =>
- list.size should be (5)
- list.count(_.attempts.head.completed) should be (3)
+ list.size should be (3)
+ list.count(_.attempts.head.completed) should be (2)
def makeAppInfo(
id: String,
@@ -132,11 +117,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
newAppComplete.lastModified(), "test", true))
list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(),
1L, 4L, newAppCompressedComplete.lastModified(), "test", true))
- list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L,
- oldAppComplete.lastModified(), "test", true))
- list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L,
- -1L, oldAppIncomplete.lastModified(), "test", false))
- list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L,
+ list(2) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L,
newAppIncomplete.lastModified(), "test", false))
// Make sure the UI can be rendered.
@@ -148,38 +129,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
}
}
- test("Parse legacy logs with compression codec set") {
- val provider = new FsHistoryProvider(createTestConf())
- val testCodecs = List((classOf[LZFCompressionCodec].getName(), true),
- (classOf[SnappyCompressionCodec].getName(), true),
- ("invalid.codec", false))
-
- testCodecs.foreach { case (codecName, valid) =>
- val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null
- val logDir = new File(testDir, codecName)
- logDir.mkdir()
- createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0"))
- writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec),
- SparkListenerApplicationStart("app2", None, 2L, "test", None),
- SparkListenerApplicationEnd(3L)
- )
- createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName))
-
- val logPath = new Path(logDir.getAbsolutePath())
- try {
- val logInput = provider.openLegacyEventLog(logPath)
- try {
- Source.fromInputStream(logInput).getLines().toSeq.size should be (2)
- } finally {
- logInput.close()
- }
- } catch {
- case e: IllegalArgumentException =>
- valid should be (false)
- }
- }
- }
-
test("SPARK-3697: ignore directories that cannot be read.") {
val logFile1 = newLogFile("new1", None, inProgress = false)
writeFile(logFile1, true, None,
@@ -395,21 +344,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
SparkListenerLogStart("1.4")
)
- // Write a 1.2 log file with no start event (= no app id), it should be ignored.
- writeOldLog("v12Log", "1.2", None, false)
-
- // Write 1.0 and 1.1 logs, which don't have app ids.
- writeOldLog("v11Log", "1.1", None, true,
- SparkListenerApplicationStart("v11Log", None, 2L, "test", None),
- SparkListenerApplicationEnd(3L))
- writeOldLog("v10Log", "1.0", None, true,
- SparkListenerApplicationStart("v10Log", None, 2L, "test", None),
- SparkListenerApplicationEnd(4L))
-
updateAndCheck(provider) { list =>
- list.size should be (2)
- list(0).id should be ("v10Log")
- list(1).id should be ("v11Log")
+ list.size should be (0)
}
}
@@ -499,25 +435,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath())
}
- private def writeOldLog(
- fname: String,
- sparkVersion: String,
- codec: Option[CompressionCodec],
- completed: Boolean,
- events: SparkListenerEvent*): File = {
- val log = new File(testDir, fname)
- log.mkdir()
-
- val oldEventLog = new File(log, LOG_PREFIX + "1")
- createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion))
- writeFile(new File(log, LOG_PREFIX + "1"), false, codec, events: _*)
- if (completed) {
- createEmptyFile(new File(log, APPLICATION_COMPLETE))
- }
-
- log
- }
-
private class SafeModeTestProvider(conf: SparkConf, clock: Clock)
extends FsHistoryProvider(conf, clock) {
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index 18659fc0c18de..40d0076eecfc8 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -139,7 +139,24 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
code should be (HttpServletResponse.SC_OK)
jsonOpt should be ('defined)
errOpt should be (None)
- val json = jsonOpt.get
+ val jsonOrg = jsonOpt.get
+
+ // SPARK-10873 added the lastUpdated field for each application's attempt.
+ // The REST API returns the last modified time of the EVENT LOG file for this field.
+ // Since this dynamic value cannot be hard-coded in a static expectation file,
+ // we skip checking the lastUpdated field's value here (setting it to "").
+ val json = if (jsonOrg.indexOf("lastUpdated") >= 0) {
+ val subStrings = jsonOrg.split(",")
+ for (i <- subStrings.indices) {
+ if (subStrings(i).indexOf("lastUpdated") >= 0) {
+ subStrings(i) = "\"lastUpdated\":\"\""
+ }
+ }
+ subStrings.mkString(",")
+ } else {
+ jsonOrg
+ }
+
val exp = IOUtils.toString(new FileInputStream(
new File(expRoot, HistoryServerSuite.sanitizePath(name) + "_expectation.json")))
// compare the ASTs so formatting differences don't cause failures
@@ -159,18 +176,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
(1 to 2).foreach { attemptId => doDownloadTest("local-1430917381535", Some(attemptId)) }
}
- test("download legacy logs - all attempts") {
- doDownloadTest("local-1426533911241", None, legacy = true)
- }
-
- test("download legacy logs - single attempts") {
- (1 to 2). foreach {
- attemptId => doDownloadTest("local-1426533911241", Some(attemptId), legacy = true)
- }
- }
-
// Test that the files are downloaded correctly, and validate them.
- def doDownloadTest(appId: String, attemptId: Option[Int], legacy: Boolean = false): Unit = {
+ def doDownloadTest(appId: String, attemptId: Option[Int]): Unit = {
val url = attemptId match {
case Some(id) =>
@@ -188,22 +195,13 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
var entry = zipStream.getNextEntry
entry should not be null
val totalFiles = {
- if (legacy) {
- attemptId.map { x => 3 }.getOrElse(6)
- } else {
- attemptId.map { x => 1 }.getOrElse(2)
- }
+ attemptId.map { x => 1 }.getOrElse(2)
}
var filesCompared = 0
while (entry != null) {
if (!entry.isDirectory) {
val expectedFile = {
- if (legacy) {
- val splits = entry.getName.split("/")
- new File(new File(logDir, splits(0)), splits(1))
- } else {
- new File(logDir, entry.getName)
- }
+ new File(logDir, entry.getName)
}
val expected = Files.toString(expectedFile, Charsets.UTF_8)
val actual = new String(ByteStreams.toByteArray(zipStream), Charsets.UTF_8)
@@ -241,30 +239,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND)
}
- test("generate history page with relative links") {
- val historyServer = mock[HistoryServer]
- val request = mock[HttpServletRequest]
- val ui = mock[SparkUI]
- val link = "/history/app1"
- val info = new ApplicationHistoryInfo("app1", "app1",
- List(ApplicationAttemptInfo(None, 0, 2, 1, "xxx", true)))
- when(historyServer.getApplicationList()).thenReturn(Seq(info))
- when(ui.basePath).thenReturn(link)
- when(historyServer.getProviderConfig()).thenReturn(Map[String, String]())
- val page = new HistoryPage(historyServer)
-
- // when
- val response = page.render(request)
-
- // then
- val links = response \\ "a"
- val justHrefs = for {
- l <- links
- attrs <- l.attribute("href")
- } yield (attrs.toString)
- justHrefs should contain (UIUtils.prependBaseUri(resource = link))
- }
-
test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") {
val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase")
val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase")
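
The expectation-matching change above blanks the dynamic `lastUpdated` value before comparing live REST output against the static JSON files. A standalone sketch of that masking step on a made-up JSON string; like the hunk above, it assumes the field values themselves contain no commas:

```scala
object MaskLastUpdatedExample extends App {
  // Same split-and-substitute masking as above, lifted into a helper.
  def mask(jsonOrg: String): String =
    if (jsonOrg.indexOf("lastUpdated") >= 0) {
      val subStrings = jsonOrg.split(",")
      for (i <- subStrings.indices) {
        if (subStrings(i).indexOf("lastUpdated") >= 0) {
          subStrings(i) = "\"lastUpdated\":\"\""
        }
      }
      subStrings.mkString(",")
    } else {
      jsonOrg
    }

  val live =
    """{"startTime":"2015-05-06T13:03:00.893GMT","lastUpdated":"2016-02-01T10:00:00.000GMT","duration":10505}"""
  println(mask(live))
  // {"startTime":"2015-05-06T13:03:00.893GMT","lastUpdated":"","duration":10505}
}
```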
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index 6ee426e1c9a5f..3a1a67cdc001a 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -551,8 +551,6 @@ private[spark] object TaskMetricsSuite extends Assertions {
* Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the
* info as an accumulator update.
*/
- def makeInfo(a: Accumulable[_, _]): AccumulableInfo = {
- new AccumulableInfo(a.id, a.name, Some(a.value), None, a.isInternal, a.countFailedValues)
- }
+ def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None)
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 6f4eda8b47dde..22048003882dd 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -20,9 +20,10 @@ package org.apache.spark.rpc
import java.io.{File, NotSerializableException}
import java.nio.charset.StandardCharsets.UTF_8
import java.util.UUID
-import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeoutException, TimeUnit}
import scala.collection.mutable
+import scala.collection.JavaConverters._
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -490,30 +491,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
/**
* Setup an [[RpcEndpoint]] to collect all network events.
- * @return the [[RpcEndpointRef]] and an `Seq` that contains network events.
+   * @return the [[RpcEndpointRef]] and a `ConcurrentLinkedQueue` that contains network events.
*/
private def setupNetworkEndpoint(
_env: RpcEnv,
- name: String): (RpcEndpointRef, Seq[(Any, Any)]) = {
- val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
+ name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = {
+ val events = new ConcurrentLinkedQueue[(Any, Any)]
val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint {
override val rpcEnv = _env
override def receive: PartialFunction[Any, Unit] = {
case "hello" =>
- case m => events += "receive" -> m
+ case m => events.add("receive" -> m)
}
override def onConnected(remoteAddress: RpcAddress): Unit = {
- events += "onConnected" -> remoteAddress
+ events.add("onConnected" -> remoteAddress)
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- events += "onDisconnected" -> remoteAddress
+ events.add("onDisconnected" -> remoteAddress)
}
override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
- events += "onNetworkError" -> remoteAddress
+ events.add("onNetworkError" -> remoteAddress)
}
})
@@ -560,7 +561,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
eventually(timeout(5 seconds), interval(5 millis)) {
// We don't know the exact client address but at least we can verify the message type
- assert(events.map(_._1).contains("onConnected"))
+ assert(events.asScala.map(_._1).exists(_ == "onConnected"))
}
clientEnv.shutdown()
@@ -568,8 +569,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
eventually(timeout(5 seconds), interval(5 millis)) {
// We don't know the exact client address but at least we can verify the message type
- assert(events.map(_._1).contains("onConnected"))
- assert(events.map(_._1).contains("onDisconnected"))
+ assert(events.asScala.map(_._1).exists(_ == "onConnected"))
+ assert(events.asScala.map(_._1).exists(_ == "onDisconnected"))
}
} finally {
clientEnv.shutdown()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index d9c71ec2eae7b..62972a0738211 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1581,12 +1581,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
assert(Accumulators.get(acc1.id).isDefined)
assert(Accumulators.get(acc2.id).isDefined)
assert(Accumulators.get(acc3.id).isDefined)
- val accInfo1 = new AccumulableInfo(
- acc1.id, acc1.name, Some(15L), None, internal = false, countFailedValues = false)
- val accInfo2 = new AccumulableInfo(
- acc2.id, acc2.name, Some(13L), None, internal = false, countFailedValues = false)
- val accInfo3 = new AccumulableInfo(
- acc3.id, acc3.name, Some(18L), None, internal = false, countFailedValues = false)
+ val accInfo1 = acc1.toInfo(Some(15L), None)
+ val accInfo2 = acc2.toInfo(Some(13L), None)
+ val accInfo3 = acc3.toInfo(Some(18L), None)
val accumUpdates = Seq(accInfo1, accInfo2, accInfo3)
val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates)
submit(new MyRDD(sc, 1, Nil), Array(0))
@@ -1954,10 +1951,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo],
taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = {
val accumUpdates = reason match {
- case Success =>
- task.initialAccumulators.map { a =>
- new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues)
- }
+ case Success => task.initialAccumulators.map { a => a.toInfo(Some(a.zero), None) }
case ef: ExceptionFailure => ef.accumUpdates
case _ => Seq.empty[AccumulableInfo]
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index a2e74365641a6..2c99dd5afb32e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -165,9 +165,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val taskSet = FakeTask.createTaskSet(1)
val clock = new ManualClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
- val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a =>
- new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues)
- }
+ val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => a.toInfo(Some(0L), None) }
// Offer a host with NO_PREF as the constraint,
// we should get a nopref task immediately since that's what we only have
@@ -186,9 +184,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val taskSet = FakeTask.createTaskSet(3)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task =>
- task.initialAccumulators.map { a =>
- new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues)
- }
+ task.initialAccumulators.map { a => a.toInfo(Some(0L), None) }
}
// First three offers should all find tasks
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 828153bdbfc44..c9c2fb2691d70 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -21,7 +21,7 @@ import java.io.InputStream
import java.util.concurrent.Semaphore
import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent.future
+import scala.concurrent.Future
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito._
@@ -149,7 +149,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
- future {
+ Future {
// Return the first two blocks, and wait till task completion before returning the 3rd one
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
@@ -211,7 +211,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
- future {
+ Future {
// Return the first block, and then fail.
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
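For context on the `future` -> `Future` change above: the lowercase `scala.concurrent.future` helper is deprecated in newer Scala versions in favour of the `Future.apply` factory. A minimal, self-contained sketch of the replacement pattern (not taken from the suite itself):

    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration._

    // Future { ... } is the Future.apply factory; it schedules the body on the
    // implicit execution context, just like the deprecated lowercase future { ... }.
    val answer: Future[Int] = Future { 40 + 2 }
    assert(Await.result(answer, 1.second) == 42)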
diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
index b207d497f33c2..6f7dddd4f760a 100644
--- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.util
-import java.util.concurrent.CountDownLatch
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch}
-import scala.collection.mutable
+import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -31,11 +31,11 @@ import org.apache.spark.SparkFunSuite
class EventLoopSuite extends SparkFunSuite with Timeouts {
test("EventLoop") {
- val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int]
+ val buffer = new ConcurrentLinkedQueue[Int]
val eventLoop = new EventLoop[Int]("test") {
override def onReceive(event: Int): Unit = {
- buffer += event
+ buffer.add(event)
}
override def onError(e: Throwable): Unit = {}
@@ -43,7 +43,7 @@ class EventLoopSuite extends SparkFunSuite with Timeouts {
eventLoop.start()
(1 to 100).foreach(eventLoop.post)
eventually(timeout(5 seconds), interval(5 millis)) {
- assert((1 to 100) === buffer.toSeq)
+ assert((1 to 100) === buffer.asScala.toSeq)
}
eventLoop.stop()
}
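The same migration away from `mutable.SynchronizedBuffer` (deprecated in Scala 2.11) appears in RpcEnvSuite above. A minimal sketch of the pattern, with made-up producer/consumer code purely for illustration:

    import java.util.concurrent.ConcurrentLinkedQueue
    import scala.collection.JavaConverters._

    // Thread-safe sink replacing `new ArrayBuffer[Int] with SynchronizedBuffer[Int]`.
    val events = new ConcurrentLinkedQueue[Int]()

    // Producers may call add from any thread.
    (1 to 100).par.foreach(events.add)

    // Consumers snapshot the queue as a Scala collection for assertions.
    assert(events.asScala.toSeq.sorted == (1 to 100))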
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 1345881a2aea3..de6f408fa82be 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -374,15 +374,18 @@ class JsonProtocolSuite extends SparkFunSuite {
test("AccumulableInfo backward compatibility") {
// "Internal" property of AccumulableInfo was added in 1.5.1
val accumulableInfo = makeAccumulableInfo(1, internal = true, countFailedValues = true)
- val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo)
- .removeField({ _._1 == "Internal" })
+ val accumulableInfoJson = JsonProtocol.accumulableInfoToJson(accumulableInfo)
+ val oldJson = accumulableInfoJson.removeField({ _._1 == "Internal" })
val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson)
assert(!oldInfo.internal)
// "Count Failed Values" property of AccumulableInfo was added in 2.0.0
- val oldJson2 = JsonProtocol.accumulableInfoToJson(accumulableInfo)
- .removeField({ _._1 == "Count Failed Values" })
+ val oldJson2 = accumulableInfoJson.removeField({ _._1 == "Count Failed Values" })
val oldInfo2 = JsonProtocol.accumulableInfoFromJson(oldJson2)
assert(!oldInfo2.countFailedValues)
+ // "Metadata" property of AccumulableInfo was added in 2.0.0
+ val oldJson3 = accumulableInfoJson.removeField({ _._1 == "Metadata" })
+ val oldInfo3 = JsonProtocol.accumulableInfoFromJson(oldJson3)
+ assert(oldInfo3.metadata.isEmpty)
}
test("ExceptionFailure backward compatibility: accumulator updates") {
@@ -820,9 +823,10 @@ private[spark] object JsonProtocolSuite extends Assertions {
private def makeAccumulableInfo(
id: Int,
internal: Boolean = false,
- countFailedValues: Boolean = false): AccumulableInfo =
+ countFailedValues: Boolean = false,
+ metadata: Option[String] = None): AccumulableInfo =
new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"),
- internal, countFailedValues)
+ internal, countFailedValues, metadata)
/**
* Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is
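The backward-compatibility tests above follow a simple pattern: serialize with the current writer, drop a field that older versions never wrote, and check that the reader falls back to a default. A self-contained json4s sketch of that pattern, assuming json4s on the classpath; the `Event` type and field names here are invented, not Spark's `JsonProtocol`:

    import org.json4s._
    import org.json4s.JsonDSL._

    case class Event(id: Long, metadata: Option[String])

    def toJson(e: Event): JValue = ("ID" -> e.id) ~ ("Metadata" -> e.metadata)

    def fromJson(json: JValue): Event = {
      implicit val formats = DefaultFormats
      // Older writers never emitted "Metadata"; treat a missing field as None.
      Event((json \ "ID").extract[Long], (json \ "Metadata").extractOpt[String])
    }

    // Simulate JSON written before the "Metadata" field existed.
    val oldJson = toJson(Event(1L, Some("x"))).removeField { case (name, _) => name == "Metadata" }
    assert(fromJson(oldJson).metadata.isEmpty)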
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index 00bf81120df65..2fd7fcc39ea28 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -134,9 +134,9 @@ if [[ "$1" == "package" ]]; then
cd spark-$SPARK_VERSION-bin-$NAME
- # TODO There should probably be a flag to make-distribution to allow 2.11 support
- if [[ $FLAGS == *scala-2.11* ]]; then
- ./dev/change-scala-version.sh 2.11
+ # TODO There should probably be a flag to make-distribution to allow 2.10 support
+ if [[ $FLAGS == *scala-2.10* ]]; then
+ ./dev/change-scala-version.sh 2.10
fi
export ZINC_PORT=$ZINC_PORT
@@ -228,8 +228,8 @@ if [[ "$1" == "publish-snapshot" ]]; then
$MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \
-Phive-thriftserver deploy
- ./dev/change-scala-version.sh 2.11
- $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \
+ ./dev/change-scala-version.sh 2.10
+ $MVN -DzincPort=$ZINC_PORT -Dscala-2.10 --settings $tmp_settings \
-DskipTests $PUBLISH_PROFILES clean deploy
# Clean-up Zinc nailgun process
@@ -266,9 +266,9 @@ if [[ "$1" == "publish-release" ]]; then
$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \
-Phive-thriftserver clean install
- ./dev/change-scala-version.sh 2.11
+ ./dev/change-scala-version.sh 2.10
- $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \
+ $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \
-DskipTests $PUBLISH_PROFILES clean install
# Clean-up Zinc nailgun process
diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2
index 4d9937c5cbc34..3a14499d9b4d9 100644
--- a/dev/deps/spark-deps-hadoop-2.2
+++ b/dev/deps/spark-deps-hadoop-2.2
@@ -14,13 +14,13 @@ avro-ipc-1.7.7-tests.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
bonecp-0.8.0.RELEASE.jar
-breeze-macros_2.10-0.11.2.jar
-breeze_2.10-0.11.2.jar
+breeze-macros_2.11-0.11.2.jar
+breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
chill-java-0.5.0.jar
-chill_2.10-0.5.0.jar
+chill_2.11-0.5.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -86,10 +86,9 @@ jackson-core-asl-1.9.13.jar
jackson-databind-2.5.3.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-scala_2.10-2.5.3.jar
+jackson-module-scala_2.11-2.5.3.jar
jackson-xc-1.9.13.jar
janino-2.7.8.jar
-jansi-1.4.jar
javax.inject-1.jar
javax.servlet-3.0.0.v201112011016.jar
javax.servlet-3.1.jar
@@ -111,15 +110,14 @@ jets3t-0.7.1.jar
jettison-1.1.jar
jetty-all-7.6.0.v20120127.jar
jetty-util-6.1.26.jar
-jline-2.10.5.jar
jline-2.12.jar
joda-time-2.9.jar
jodd-core-3.5.2.jar
jpam-1.1.jar
json-20090211.jar
-json4s-ast_2.10-3.2.10.jar
-json4s-core_2.10-3.2.10.jar
-json4s-jackson_2.10-3.2.10.jar
+json4s-ast_2.11-3.2.10.jar
+json4s-core_2.11-3.2.10.jar
+json4s-jackson_2.11-3.2.10.jar
jsr305-1.3.9.jar
jta-1.1.jar
jtransforms-2.4.0.jar
@@ -158,19 +156,20 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.1.jar
pyrolite-4.9.jar
-quasiquotes_2.10-2.0.0-M8.jar
reflectasm-1.07-shaded.jar
-scala-compiler-2.10.5.jar
-scala-library-2.10.5.jar
-scala-reflect-2.10.5.jar
-scalap-2.10.5.jar
+scala-compiler-2.11.7.jar
+scala-library-2.11.7.jar
+scala-parser-combinators_2.11-1.0.4.jar
+scala-reflect-2.11.7.jar
+scala-xml_2.11-1.0.2.jar
+scalap-2.11.7.jar
servlet-api-2.5.jar
slf4j-api-1.7.10.jar
slf4j-log4j12-1.7.10.jar
snappy-0.2.jar
snappy-java-1.1.2.jar
-spire-macros_2.10-0.7.4.jar
-spire_2.10-0.7.4.jar
+spire-macros_2.11-0.7.4.jar
+spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3
index fd659ee20df1a..615836b3d3b77 100644
--- a/dev/deps/spark-deps-hadoop-2.3
+++ b/dev/deps/spark-deps-hadoop-2.3
@@ -16,13 +16,13 @@ avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
bcprov-jdk15on-1.51.jar
bonecp-0.8.0.RELEASE.jar
-breeze-macros_2.10-0.11.2.jar
-breeze_2.10-0.11.2.jar
+breeze-macros_2.11-0.11.2.jar
+breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
chill-java-0.5.0.jar
-chill_2.10-0.5.0.jar
+chill_2.11-0.5.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -81,10 +81,9 @@ jackson-core-asl-1.9.13.jar
jackson-databind-2.5.3.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-scala_2.10-2.5.3.jar
+jackson-module-scala_2.11-2.5.3.jar
jackson-xc-1.9.13.jar
janino-2.7.8.jar
-jansi-1.4.jar
java-xmlbuilder-1.0.jar
javax.inject-1.jar
javax.servlet-3.0.0.v201112011016.jar
@@ -102,15 +101,14 @@ jettison-1.1.jar
jetty-6.1.26.jar
jetty-all-7.6.0.v20120127.jar
jetty-util-6.1.26.jar
-jline-2.10.5.jar
jline-2.12.jar
joda-time-2.9.jar
jodd-core-3.5.2.jar
jpam-1.1.jar
json-20090211.jar
-json4s-ast_2.10-3.2.10.jar
-json4s-core_2.10-3.2.10.jar
-json4s-jackson_2.10-3.2.10.jar
+json4s-ast_2.11-3.2.10.jar
+json4s-core_2.11-3.2.10.jar
+json4s-jackson_2.11-3.2.10.jar
jsr305-1.3.9.jar
jta-1.1.jar
jtransforms-2.4.0.jar
@@ -149,19 +147,20 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.1.jar
pyrolite-4.9.jar
-quasiquotes_2.10-2.0.0-M8.jar
reflectasm-1.07-shaded.jar
-scala-compiler-2.10.5.jar
-scala-library-2.10.5.jar
-scala-reflect-2.10.5.jar
-scalap-2.10.5.jar
+scala-compiler-2.11.7.jar
+scala-library-2.11.7.jar
+scala-parser-combinators_2.11-1.0.4.jar
+scala-reflect-2.11.7.jar
+scala-xml_2.11-1.0.2.jar
+scalap-2.11.7.jar
servlet-api-2.5.jar
slf4j-api-1.7.10.jar
slf4j-log4j12-1.7.10.jar
snappy-0.2.jar
snappy-java-1.1.2.jar
-spire-macros_2.10-0.7.4.jar
-spire_2.10-0.7.4.jar
+spire-macros_2.11-0.7.4.jar
+spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4
index afae3deb9ada2..f275226f1d088 100644
--- a/dev/deps/spark-deps-hadoop-2.4
+++ b/dev/deps/spark-deps-hadoop-2.4
@@ -16,13 +16,13 @@ avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
bcprov-jdk15on-1.51.jar
bonecp-0.8.0.RELEASE.jar
-breeze-macros_2.10-0.11.2.jar
-breeze_2.10-0.11.2.jar
+breeze-macros_2.11-0.11.2.jar
+breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
chill-java-0.5.0.jar
-chill_2.10-0.5.0.jar
+chill_2.11-0.5.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -81,10 +81,9 @@ jackson-core-asl-1.9.13.jar
jackson-databind-2.5.3.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-scala_2.10-2.5.3.jar
+jackson-module-scala_2.11-2.5.3.jar
jackson-xc-1.9.13.jar
janino-2.7.8.jar
-jansi-1.4.jar
java-xmlbuilder-1.0.jar
javax.inject-1.jar
javax.servlet-3.0.0.v201112011016.jar
@@ -103,15 +102,14 @@ jettison-1.1.jar
jetty-6.1.26.jar
jetty-all-7.6.0.v20120127.jar
jetty-util-6.1.26.jar
-jline-2.10.5.jar
jline-2.12.jar
joda-time-2.9.jar
jodd-core-3.5.2.jar
jpam-1.1.jar
json-20090211.jar
-json4s-ast_2.10-3.2.10.jar
-json4s-core_2.10-3.2.10.jar
-json4s-jackson_2.10-3.2.10.jar
+json4s-ast_2.11-3.2.10.jar
+json4s-core_2.11-3.2.10.jar
+json4s-jackson_2.11-3.2.10.jar
jsr305-1.3.9.jar
jta-1.1.jar
jtransforms-2.4.0.jar
@@ -150,19 +148,20 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.1.jar
pyrolite-4.9.jar
-quasiquotes_2.10-2.0.0-M8.jar
reflectasm-1.07-shaded.jar
-scala-compiler-2.10.5.jar
-scala-library-2.10.5.jar
-scala-reflect-2.10.5.jar
-scalap-2.10.5.jar
+scala-compiler-2.11.7.jar
+scala-library-2.11.7.jar
+scala-parser-combinators_2.11-1.0.4.jar
+scala-reflect-2.11.7.jar
+scala-xml_2.11-1.0.2.jar
+scalap-2.11.7.jar
servlet-api-2.5.jar
slf4j-api-1.7.10.jar
slf4j-log4j12-1.7.10.jar
snappy-0.2.jar
snappy-java-1.1.2.jar
-spire-macros_2.10-0.7.4.jar
-spire_2.10-0.7.4.jar
+spire-macros_2.11-0.7.4.jar
+spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 5a6460136a3a0..21432a16e3659 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -20,13 +20,13 @@ avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
bcprov-jdk15on-1.51.jar
bonecp-0.8.0.RELEASE.jar
-breeze-macros_2.10-0.11.2.jar
-breeze_2.10-0.11.2.jar
+breeze-macros_2.11-0.11.2.jar
+breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
chill-java-0.5.0.jar
-chill_2.10-0.5.0.jar
+chill_2.11-0.5.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -87,10 +87,9 @@ jackson-core-asl-1.9.13.jar
jackson-databind-2.5.3.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-scala_2.10-2.5.3.jar
+jackson-module-scala_2.11-2.5.3.jar
jackson-xc-1.9.13.jar
janino-2.7.8.jar
-jansi-1.4.jar
java-xmlbuilder-1.0.jar
javax.inject-1.jar
javax.servlet-3.0.0.v201112011016.jar
@@ -109,15 +108,14 @@ jettison-1.1.jar
jetty-6.1.26.jar
jetty-all-7.6.0.v20120127.jar
jetty-util-6.1.26.jar
-jline-2.10.5.jar
jline-2.12.jar
joda-time-2.9.jar
jodd-core-3.5.2.jar
jpam-1.1.jar
json-20090211.jar
-json4s-ast_2.10-3.2.10.jar
-json4s-core_2.10-3.2.10.jar
-json4s-jackson_2.10-3.2.10.jar
+json4s-ast_2.11-3.2.10.jar
+json4s-core_2.11-3.2.10.jar
+json4s-jackson_2.11-3.2.10.jar
jsr305-1.3.9.jar
jta-1.1.jar
jtransforms-2.4.0.jar
@@ -156,19 +154,20 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.1.jar
pyrolite-4.9.jar
-quasiquotes_2.10-2.0.0-M8.jar
reflectasm-1.07-shaded.jar
-scala-compiler-2.10.5.jar
-scala-library-2.10.5.jar
-scala-reflect-2.10.5.jar
-scalap-2.10.5.jar
+scala-compiler-2.11.7.jar
+scala-library-2.11.7.jar
+scala-parser-combinators_2.11-1.0.4.jar
+scala-reflect-2.11.7.jar
+scala-xml_2.11-1.0.2.jar
+scalap-2.11.7.jar
servlet-api-2.5.jar
slf4j-api-1.7.10.jar
slf4j-log4j12-1.7.10.jar
snappy-0.2.jar
snappy-java-1.1.2.jar
-spire-macros_2.10-0.7.4.jar
-spire_2.10-0.7.4.jar
+spire-macros_2.11-0.7.4.jar
+spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 70083e7f3d16a..20e09cd002635 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -20,13 +20,13 @@ avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
bcprov-jdk15on-1.51.jar
bonecp-0.8.0.RELEASE.jar
-breeze-macros_2.10-0.11.2.jar
-breeze_2.10-0.11.2.jar
+breeze-macros_2.11-0.11.2.jar
+breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
chill-java-0.5.0.jar
-chill_2.10-0.5.0.jar
+chill_2.11-0.5.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -87,10 +87,9 @@ jackson-core-asl-1.9.13.jar
jackson-databind-2.5.3.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-scala_2.10-2.5.3.jar
+jackson-module-scala_2.11-2.5.3.jar
jackson-xc-1.9.13.jar
janino-2.7.8.jar
-jansi-1.4.jar
java-xmlbuilder-1.0.jar
javax.inject-1.jar
javax.servlet-3.0.0.v201112011016.jar
@@ -109,15 +108,14 @@ jettison-1.1.jar
jetty-6.1.26.jar
jetty-all-7.6.0.v20120127.jar
jetty-util-6.1.26.jar
-jline-2.10.5.jar
jline-2.12.jar
joda-time-2.9.jar
jodd-core-3.5.2.jar
jpam-1.1.jar
json-20090211.jar
-json4s-ast_2.10-3.2.10.jar
-json4s-core_2.10-3.2.10.jar
-json4s-jackson_2.10-3.2.10.jar
+json4s-ast_2.11-3.2.10.jar
+json4s-core_2.11-3.2.10.jar
+json4s-jackson_2.11-3.2.10.jar
jsp-api-2.1.jar
jsr305-1.3.9.jar
jta-1.1.jar
@@ -157,19 +155,20 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.1.jar
pyrolite-4.9.jar
-quasiquotes_2.10-2.0.0-M8.jar
reflectasm-1.07-shaded.jar
-scala-compiler-2.10.5.jar
-scala-library-2.10.5.jar
-scala-reflect-2.10.5.jar
-scalap-2.10.5.jar
+scala-compiler-2.11.7.jar
+scala-library-2.11.7.jar
+scala-parser-combinators_2.11-1.0.4.jar
+scala-reflect-2.11.7.jar
+scala-xml_2.11-1.0.2.jar
+scalap-2.11.7.jar
servlet-api-2.5.jar
slf4j-api-1.7.10.jar
slf4j-log4j12-1.7.10.jar
snappy-0.2.jar
snappy-java-1.1.2.jar
-spire-macros_2.10-0.7.4.jar
-spire_2.10-0.7.4.jar
+spire-macros_2.11-0.7.4.jar
+spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml
index 78b638ecfa638..833ca29cd8218 100644
--- a/docker-integration-tests/pom.xml
+++ b/docker-integration-tests/pom.xml
@@ -21,12 +21,12 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xml
- spark-docker-integration-tests_2.10
+ spark-docker-integration-tests_2.11jarSpark Project Docker Integration Testshttp://spark.apache.org/
diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 7d011be37067b..72bda8fe1ef10 100644
--- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -21,7 +21,7 @@ import java.sql.Connection
import java.util.Properties
import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions.{If, Literal}
+import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.tags.DockerTest
@DockerTest
@@ -39,12 +39,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
override def dataPreparation(conn: Connection): Unit = {
conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
conn.setCatalog("foo")
+ conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate()
conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
+ "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
- + "c10 integer[], c11 text[], c12 real[])").executeUpdate()
+ + "c10 integer[], c11 text[], c12 real[], c13 enum_type)").executeUpdate()
conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
+ "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
- + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}')""").executeUpdate()
+ + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', 'd1')""").executeUpdate()
}
test("Type mapping for various types") {
@@ -52,7 +53,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass)
- assert(types.length == 13)
+ assert(types.length == 14)
assert(classOf[String].isAssignableFrom(types(0)))
assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
@@ -66,22 +67,24 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
assert(classOf[Seq[String]].isAssignableFrom(types(11)))
assert(classOf[Seq[Double]].isAssignableFrom(types(12)))
+ assert(classOf[String].isAssignableFrom(types(13)))
assert(rows(0).getString(0).equals("hello"))
assert(rows(0).getInt(1) == 42)
assert(rows(0).getDouble(2) == 1.25)
assert(rows(0).getLong(3) == 123456789012345L)
- assert(rows(0).getBoolean(4) == false)
+ assert(!rows(0).getBoolean(4))
// BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's...
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5),
Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49)))
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6),
Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte)))
- assert(rows(0).getBoolean(7) == true)
+ assert(rows(0).getBoolean(7))
assert(rows(0).getString(8) == "172.16.0.42")
assert(rows(0).getString(9) == "192.168.0.0/16")
assert(rows(0).getSeq(10) == Seq(1, 2))
assert(rows(0).getSeq(11) == Seq("a", null, "b"))
assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f))
+ assert(rows(0).getString(13) == "d1")
}
test("Basic write test") {
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index 174c202e37918..f926d67e6beaf 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -37,7 +37,7 @@
# Copy over the unified ScalaDoc for all projects to api/scala.
# This directory will be copied over to _site when `jekyll` command is run.
- source = "../target/scala-2.10/unidoc"
+ source = "../target/scala-2.11/unidoc"
dest = "api/scala"
puts "Making directory " + dest
diff --git a/docs/building-spark.md b/docs/building-spark.md
index e1abcf1be501d..975e1b295c8ae 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -114,13 +114,11 @@ By default Spark will build with Hive 0.13.1 bindings.
mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package
{% endhighlight %}
-# Building for Scala 2.11
-To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property:
+# Building for Scala 2.10
+To produce a Spark package compiled with Scala 2.10, use the `-Dscala-2.10` property:
- ./dev/change-scala-version.sh 2.11
- mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package
-
-Spark does not yet support its JDBC component for Scala 2.11.
+ ./dev/change-scala-version.sh 2.10
+ mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
# Spark Tests in Maven
diff --git a/docs/configuration.md b/docs/configuration.md
index 74a8fb5d35a66..cd9dc1bcfc113 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1169,8 +1169,8 @@ Apart from these, the following properties are also available, and may be useful
false
Whether to use dynamic resource allocation, which scales the number of executors registered
- with this application up and down based on the workload. Note that this is currently only
- available on YARN mode. For more detail, see the description
+ with this application up and down based on the workload.
+ For more detail, see the description
here.
This requires spark.shuffle.service.enabled to be set.
@@ -1585,6 +1585,29 @@ Apart from these, the following properties are also available, and may be useful
+#### Deploy
+
+<table class="table">
+  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+  <tr>
+    <td><code>spark.deploy.recoveryMode</code></td>
+    <td>NONE</td>
+    <td>The recovery mode setting used to recover submitted Spark jobs in cluster mode when the master fails and is relaunched.
+    This is only applicable for cluster mode when running with Standalone or Mesos.</td>
+  </tr>
+  <tr>
+    <td><code>spark.deploy.zookeeper.url</code></td>
+    <td>None</td>
+    <td>When `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the ZooKeeper URL to connect to.</td>
+  </tr>
+  <tr>
+    <td><code>spark.deploy.zookeeper.dir</code></td>
+    <td>None</td>
+    <td>When `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the ZooKeeper directory to store recovery state.</td>
+  </tr>
+</table>
+
#### Cluster Managers
Each cluster manager in Spark has additional configuration options. Configurations
can be found on the pages for each mode:
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index 8ffc997b4bf5a..9569a06472cbf 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -289,7 +289,7 @@ The example below demonstrates how to load the
-Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details.
+Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.OneVsRest) for more details.
{% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %}
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 5aafd53b584e7..f8279262e673f 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -627,7 +627,7 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/
The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator)
for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator)
-for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator)
+for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator)
for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetricName`
method in each of these evaluators.
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index ed720f1039f94..e1c87a8d95a32 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -153,7 +153,10 @@ can find the results of the driver from the Mesos Web UI.
To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script,
passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host.
-If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`).
+If you would like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e. `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` does not yet support multiple instances for HA.
+
+The `MesosClusterDispatcher` also supports writing recovery state into ZooKeeper. This allows the `MesosClusterDispatcher` to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations.
+For more information about these configurations please refer to the configuration [doc](configuration.html#deploy).
From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL
to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the
@@ -243,18 +246,15 @@ In either case, HDFS runs separately from Hadoop MapReduce, without being schedu
# Dynamic Resource Allocation with Mesos
-Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics
-of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down
-since it is already designed to run one executor per slave with the configured amount of resources. However, after scaling down the number of executors the coarse grain scheduler
-can scale back up to the same amount of executors when Spark signals more executors are needed.
+Mesos supports dynamic allocation only with coarse-grained mode, which can resize the number of
+executors based on statistics of the application. For general information,
+see [Dynamic Resource Allocation](job-scheduling.html#dynamic-resource-allocation).
-Users that like to utilize this feature should launch the Mesos Shuffle Service that
-provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's
-termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh
-scripts accordingly.
+The External Shuffle Service to use is the Mesos Shuffle Service. It provides shuffle data cleanup functionality
+on top of the Shuffle Service, since Mesos does not yet support notifying other frameworks when a framework
+terminates. To launch it, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all slave nodes, with `spark.shuffle.service.enabled` set to `true`.
-The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos
-is to launch the Shuffle Service with Marathon with a unique host constraint.
+This can also be achieved through Marathon, using a unique host constraint, and the following command: `bin/spark-class org.apache.spark.deploy.mesos.MesosExternalShuffleService`.
# Configuration
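On the application side, dynamic allocation on Mesos uses the same settings as on other cluster managers. A minimal sketch, assuming the Mesos shuffle service described above is already running on each slave; the master URL and app name are placeholders:

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("mesos://host:5050")                // placeholder Mesos master
      .setAppName("dynamic-allocation-example")
      .set("spark.mesos.coarse", "true")             // coarse-grained mode is required
      .set("spark.dynamicAllocation.enabled", "true")
      .set("spark.shuffle.service.enabled", "true")
    val sc = new SparkContext(conf)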
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 2fe9ec3542b28..3de72bc016dd4 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -112,8 +112,8 @@ You can optionally configure the cluster further by setting environment variable
SPARK_LOCAL_DIRS
- Directory to use for "scratch" space in Spark, including map output files and RDDs that get
- stored on disk. This should be on a fast, local disk in your system. It can also be a
+ Directory to use for "scratch" space in Spark, including map output files and RDDs that get
+ stored on disk. This should be on a fast, local disk in your system. It can also be a
comma-separated list of multiple directories on different disks.
@@ -341,23 +341,8 @@ Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.o
**Configuration**
-In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration:
-
-<table class="table">
-  <tr><th>System property</th><th>Meaning</th></tr>
-  <tr>
-    <td><code>spark.deploy.recoveryMode</code></td>
-    <td>Set to ZOOKEEPER to enable standby Master recovery mode (default: NONE).</td>
-  </tr>
-  <tr>
-    <td><code>spark.deploy.zookeeper.url</code></td>
-    <td>The ZooKeeper cluster url (e.g., 192.168.1.100:2181,192.168.1.101:2181).</td>
-  </tr>
-  <tr>
-    <td><code>spark.deploy.zookeeper.dir</code></td>
-    <td>The directory in ZooKeeper to store recovery state (default: /spark).</td>
-  </tr>
-</table>
-
+In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations.
+For more information about these configurations please refer to the configuration [doc](configuration.html#deploy).
Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently).
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index fddc51379406b..550a40010e828 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1695,7 +1695,7 @@ on all of the worker nodes, as they will need access to the Hive serialization a
Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration),
`hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running
-the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory
+the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib` directory
and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the
YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the
`spark-submit` command.
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index acbb0f298fe47..413532f2f6cfa 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -177,8 +177,9 @@ debugging information by running `spark-submit` with the `--verbose` option.
# Advanced Dependency Management
When using `spark-submit`, the application jar along with any jars included with the `--jars` option
-will be automatically transferred to the cluster. Spark uses the following URL scheme to allow
-different strategies for disseminating jars:
+will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`.
+
+Spark uses the following URL scheme to allow different strategies for disseminating jars:
- **file:** - Absolute paths and `file:/` URIs are served by the driver's HTTP file server, and
every executor pulls the file from the driver HTTP server.
diff --git a/examples/pom.xml b/examples/pom.xml
index 9437cee2abfdf..82baa9085b4f9 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xmlorg.apache.spark
- spark-examples_2.10
+ spark-examples_2.11examples
diff --git a/external/akka/pom.xml b/external/akka/pom.xml
index 06c8e8aaabd8c..bbe644e3b32b3 100644
--- a/external/akka/pom.xml
+++ b/external/akka/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-akka_2.10
+ spark-streaming-akka_2.11streaming-akka
diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml
index b2c377fe4cc9b..ac15b93c048da 100644
--- a/external/flume-assembly/pom.xml
+++ b/external/flume-assembly/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-flume-assembly_2.10
+ spark-streaming-flume-assembly_2.11jarSpark Project External Flume Assemblyhttp://spark.apache.org/
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 4b6485ee0a71a..e4effe158c826 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-flume-sink_2.10
+ spark-streaming-flume-sink_2.11streaming-flume-sink
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index a79656c6f7d96..d650dd034d636 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-flume_2.10
+ spark-streaming-flume_2.11streaming-flume
diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml
index 0c466b3c4ac37..62818f5e8f434 100644
--- a/external/kafka-assembly/pom.xml
+++ b/external/kafka-assembly/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-kafka-assembly_2.10
+ spark-streaming-kafka-assembly_2.11jarSpark Project External Kafka Assemblyhttp://spark.apache.org/
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 5180ab6dbafbd..68d52e9339b3d 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-kafka_2.10
+ spark-streaming-kafka_2.11streaming-kafka
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index d7885d7cc1ae1..8a66621a3125c 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -29,15 +29,19 @@ import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, To
import kafka.consumer.{ConsumerConfig, SimpleConsumer}
import org.apache.spark.SparkException
+import org.apache.spark.annotation.DeveloperApi
/**
+ * :: DeveloperApi ::
* Convenience methods for interacting with a Kafka cluster.
+ * See
+ * A Guide To The Kafka Protocol for more details on individual api calls.
* @param kafkaParams Kafka
* configuration parameters.
* Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
* NOT zookeeper servers, specified in host1:port1,host2:port2 form
*/
-private[spark]
+@DeveloperApi
class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
import KafkaCluster.{Err, LeaderOffset, SimpleConsumerConfig}
@@ -227,7 +231,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
// this 0 here indicates api version, in this case the original ZK backed api.
private def defaultConsumerApiVersion: Short = 0
- /** Requires Kafka >= 0.8.1.1 */
+ /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */
def getConsumerOffsets(
groupId: String,
topicAndPartitions: Set[TopicAndPartition]
@@ -246,7 +250,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
}
}
- /** Requires Kafka >= 0.8.1.1 */
+ /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */
def getConsumerOffsetMetadata(
groupId: String,
topicAndPartitions: Set[TopicAndPartition]
@@ -283,7 +287,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
Left(errs)
}
- /** Requires Kafka >= 0.8.1.1 */
+ /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */
def setConsumerOffsets(
groupId: String,
offsets: Map[TopicAndPartition, Long]
@@ -301,7 +305,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
setConsumerOffsetMetadata(groupId, meta, consumerApiVersion)
}
- /** Requires Kafka >= 0.8.1.1 */
+ /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */
def setConsumerOffsetMetadata(
groupId: String,
metadata: Map[TopicAndPartition, OffsetAndMetadata]
@@ -359,7 +363,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
}
}
-private[spark]
+@DeveloperApi
object KafkaCluster {
type Err = ArrayBuffer[Throwable]
@@ -371,7 +375,6 @@ object KafkaCluster {
)
}
- private[spark]
case class LeaderOffset(host: String, port: Int, offset: Long)
/**
@@ -379,7 +382,6 @@ object KafkaCluster {
* Simple consumers connect directly to brokers, but need many of the same configs.
* This subclass won't warn about missing ZK params, or presence of broker params.
*/
- private[spark]
class SimpleConsumerConfig private(brokers: String, originalProps: Properties)
extends ConsumerConfig(originalProps) {
val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp =>
@@ -391,7 +393,6 @@ object KafkaCluster {
}
}
- private[spark]
object SimpleConsumerConfig {
/**
* Make a consumer config without requiring group.id or zookeeper.connect,
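Since `KafkaCluster` is now a `DeveloperApi` rather than `private[spark]`, user code can query offsets directly. A hedged sketch: broker addresses, topic, and group id are placeholders, and `getConsumerOffsets` is assumed to return `Either[Err, Map[TopicAndPartition, Long]]` as in the source above.

    import kafka.common.TopicAndPartition
    import org.apache.spark.streaming.kafka.KafkaCluster

    val kc = new KafkaCluster(Map("metadata.broker.list" -> "broker1:9092,broker2:9092"))
    val partitions = Set(TopicAndPartition("events", 0), TopicAndPartition("events", 1))

    // Read the committed offsets for a consumer group, or report the collected errors.
    kc.getConsumerOffsets("my-consumer-group", partitions) match {
      case Right(offsets) => offsets.foreach { case (tp, off) => println(s"$tp is at offset $off") }
      case Left(errs)     => errs.foreach(err => println(s"offset lookup failed: $err"))
    }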
diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml
index c4a1ae26ea699..ac2a3f65ed2f5 100644
--- a/external/mqtt-assembly/pom.xml
+++ b/external/mqtt-assembly/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-mqtt-assembly_2.10
+ spark-streaming-mqtt-assembly_2.11jarSpark Project External MQTT Assemblyhttp://spark.apache.org/
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index d3a2bf5825b08..d0d968782c7f1 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-mqtt_2.10
+ spark-streaming-mqtt_2.11streaming-mqtt
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index 7b628b09ea6a5..5d4053afcbba7 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-twitter_2.10
+ spark-streaming-twitter_2.11streaming-twitter
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index 7781aaeed9e0c..f16bc0f319744 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-zeromq_2.10
+ spark-streaming-zeromq_2.11streaming-zeromq
diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml
index 4dfe3b654df1a..0ad9c5303a36a 100644
--- a/extras/java8-tests/pom.xml
+++ b/extras/java8-tests/pom.xml
@@ -19,13 +19,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- java8-tests_2.10
+ java8-tests_2.11pomSpark Project Java8 Tests POM
diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml
index 601080c2e6fbd..d1c38c7ca5d69 100644
--- a/extras/kinesis-asl-assembly/pom.xml
+++ b/extras/kinesis-asl-assembly/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-kinesis-asl-assembly_2.10
+ spark-streaming-kinesis-asl-assembly_2.11jarSpark Project Kinesis Assemblyhttp://spark.apache.org/
diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
index 20e2c5e0ffbee..935155eb5d362 100644
--- a/extras/kinesis-asl/pom.xml
+++ b/extras/kinesis-asl/pom.xml
@@ -19,14 +19,14 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-streaming-kinesis-asl_2.10
+ spark-streaming-kinesis-asl_2.11jarSpark Kinesis Integration
diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml
index b046a10a04d5b..bfb92791de3d8 100644
--- a/extras/spark-ganglia-lgpl/pom.xml
+++ b/extras/spark-ganglia-lgpl/pom.xml
@@ -19,14 +19,14 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-ganglia-lgpl_2.10
+ spark-ganglia-lgpl_2.11jarSpark Ganglia Integration
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 388a0ef06a2b0..1813f383cdcba 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xmlorg.apache.spark
- spark-graphx_2.10
+ spark-graphx_2.11graphx
diff --git a/launcher/pom.xml b/launcher/pom.xml
index 135866cea2e74..ef731948826ef 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -21,13 +21,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xmlorg.apache.spark
- spark-launcher_2.10
+ spark-launcher_2.11jarSpark Project Launcherhttp://spark.apache.org/
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 42af2b8b3e411..816f3f6830382 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xmlorg.apache.spark
- spark-mllib_2.10
+ spark-mllib_2.11mllib
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 9b2340a1f16fc..ac0124513f283 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -332,12 +332,13 @@ class LogisticRegression @Since("1.2.0") (
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
} else {
+ val standardizationParam = $(standardization)
def regParamL1Fun = (index: Int) => {
// Remove the L1 penalization on the intercept
if (index == numFeatures) {
0.0
} else {
- if ($(standardization)) {
+ if (standardizationParam) {
regParamL1
} else {
// If `standardization` is false, we still standardize the data
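A plausible reason for copying `$(standardization)` into the local `standardizationParam` above is to keep the per-index closure from re-evaluating the param lookup and from capturing the enclosing estimator. A tiny, made-up illustration of the capture difference (not Spark code):

    class Scaler(factor: Double) extends Serializable {
      // Referencing `factor` directly inside the closure captures `this`,
      // so the whole Scaler instance travels with the function.
      def scaleCapturingThis(xs: Seq[Double]): Seq[Double] = xs.map(_ * factor)

      // Copying the value first means the closure only carries a Double.
      def scaleCapturingValue(xs: Seq[Double]): Seq[Double] = {
        val f = factor
        xs.map(_ * f)
      }
    }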
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index b93c9ed382bdf..e53ef300f644b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String)
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be ArrayType(StringType) but got $inputType.")
- val outputFields = schema.fields :+
- StructField($(outputCol), inputType, schema($(inputCol)).nullable)
- StructType(outputFields)
+ SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
}
override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
new file mode 100644
index 0000000000000..6aa44e6ba723e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.rdd.RDD
+
+/**
+ * Model fitted by [[IterativelyReweightedLeastSquares]].
+ * @param coefficients model coefficients
+ * @param intercept model intercept
+ */
+private[ml] class IterativelyReweightedLeastSquaresModel(
+ val coefficients: DenseVector,
+ val intercept: Double) extends Serializable
+
+/**
+ * Implements the method of iteratively reweighted least squares (IRLS), which is used to solve
+ * certain optimization problems iteratively. Each iteration solves a weighted least
+ * squares (WLS) problem by [[WeightedLeastSquares]].
+ * It can be used to find maximum likelihood estimates of a generalized linear model (GLM),
+ * to find M-estimators in robust regression, and for other optimization problems.
+ *
+ * @param initialModel the initial guess model.
+ * @param reweightFunc the reweight function which is used to update offsets and weights
+ * at each iteration.
+ * @param fitIntercept whether to fit intercept.
+ * @param regParam L2 regularization parameter used by WLS.
+ * @param maxIter maximum number of iterations.
+ * @param tol the convergence tolerance.
+ *
+ * @see [[http://www.jstor.org/stable/2345503 P. J. Green, Iteratively Reweighted Least Squares
+ * for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives,
+ * Journal of the Royal Statistical Society. Series B, 1984.]]
+ */
+private[ml] class IterativelyReweightedLeastSquares(
+ val initialModel: WeightedLeastSquaresModel,
+ val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double),
+ val fitIntercept: Boolean,
+ val regParam: Double,
+ val maxIter: Int,
+ val tol: Double) extends Logging with Serializable {
+
+ def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = {
+
+ var converged = false
+ var iter = 0
+
+ var model: WeightedLeastSquaresModel = initialModel
+ var oldModel: WeightedLeastSquaresModel = null
+
+ while (iter < maxIter && !converged) {
+
+ oldModel = model
+
+ // Update offsets and weights using reweightFunc
+ val newInstances = instances.map { instance =>
+ val (newOffset, newWeight) = reweightFunc(instance, oldModel)
+ Instance(newOffset, newWeight, instance.features)
+ }
+
+ // Estimate new model
+ model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false,
+ standardizeLabel = false).fit(newInstances)
+
+ // Check convergence
+ val oldCoefficients = oldModel.coefficients
+ val coefficients = model.coefficients
+ BLAS.axpy(-1.0, coefficients, oldCoefficients)
+ val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) =>
+ math.max(math.abs(x), math.abs(y))
+ }
+ val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept))
+
+ if (maxTol < tol) {
+ converged = true
+ logInfo(s"IRLS converged in $iter iterations.")
+ }
+
+ logInfo(s"Iteration $iter : relative tolerance = $maxTol")
+ iter = iter + 1
+
+ if (iter == maxIter) {
+ logInfo(s"IRLS reached the max number of iterations: $maxIter.")
+ }
+
+ }
+
+ new IterativelyReweightedLeastSquaresModel(model.coefficients, model.intercept)
+ }
+}
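A hedged usage sketch of the new solver: the solver classes are `private[ml]`, so real callers sit inside `org.apache.spark.ml`, and the data plus the identity-style reweight function below are invented purely for illustration.

    import org.apache.spark.SparkContext
    import org.apache.spark.ml.feature.Instance
    import org.apache.spark.ml.optim.{IterativelyReweightedLeastSquares, WeightedLeastSquares, WeightedLeastSquaresModel}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.rdd.RDD

    def fitWithIrls(sc: SparkContext): Unit = {
      val instances: RDD[Instance] = sc.parallelize(Seq(
        Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)),
        Instance(2.0, 1.0, Vectors.dense(1.0, 0.5)),
        Instance(3.0, 1.0, Vectors.dense(2.0, 0.1))))

      // Initial guess from a single weighted least squares fit.
      val initial = new WeightedLeastSquares(fitIntercept = true, regParam = 0.0,
        standardizeFeatures = false, standardizeLabel = false).fit(instances)

      // Reweight function: produces the (offset, weight) pair fed to the next WLS step.
      // A real GLM would derive these from the link function and variance; this toy
      // version simply reuses the label and weight, so IRLS converges immediately.
      val reweight = (instance: Instance, model: WeightedLeastSquaresModel) =>
        (instance.label, instance.weight)

      val irls = new IterativelyReweightedLeastSquares(initial, reweight,
        fitIntercept = true, regParam = 0.0, maxIter = 25, tol = 1e-6)
      val model = irls.fit(instances)
      println(s"coefficients=${model.coefficients} intercept=${model.intercept}")
    }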
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 797870eb8ce8a..61b3642131810 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -31,7 +31,12 @@ import org.apache.spark.rdd.RDD
private[ml] class WeightedLeastSquaresModel(
val coefficients: DenseVector,
val intercept: Double,
- val diagInvAtWA: DenseVector) extends Serializable
+ val diagInvAtWA: DenseVector) extends Serializable {
+
+ def predict(features: Vector): Double = {
+ BLAS.dot(coefficients, features) + intercept
+ }
+}
/**
* Weighted least squares solver via normal equation.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index f48923d69974b..d7d6c0f5fa16e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -117,7 +117,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
}
}
- override final def toString: String = s"${parent}__$name"
+ private[this] val stringRepresentation = s"${parent}__$name"
+
+ override final def toString: String = stringRepresentation
override final def hashCode: Int = toString.##
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index c54e08b2ad9a5..e253f25c0ea65 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -219,33 +219,49 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
}
val yMean = ySummarizer.mean(0)
- val yStd = math.sqrt(ySummarizer.variance(0))
-
- // If the yStd is zero, then the intercept is yMean with zero coefficient;
- // as a result, training is not needed.
- if (yStd == 0.0) {
- logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
- s"zeros and the intercept will be the mean of the label; as a result, " +
- s"training is not needed.")
- if (handlePersistence) instances.unpersist()
- val coefficients = Vectors.sparse(numFeatures, Seq())
- val intercept = yMean
-
- val model = new LinearRegressionModel(uid, coefficients, intercept)
- // Handle possible missing or invalid prediction columns
- val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
-
- val trainingSummary = new LinearRegressionTrainingSummary(
- summaryModel.transform(dataset),
- predictionColName,
- $(labelCol),
- model,
- Array(0D),
- $(featuresCol),
- Array(0D))
- return copyValues(model.setSummary(trainingSummary))
+ val rawYStd = math.sqrt(ySummarizer.variance(0))
+ if (rawYStd == 0.0) {
+ if ($(fitIntercept) || yMean==0.0) {
+ // If the rawYStd is zero and fitIntercept=true, then the intercept is yMean with
+ // zero coefficient; as a result, training is not needed.
+ // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of
+ // the fitIntercept
+ if (yMean == 0.0) {
+ logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " +
+ s"and the intercept will all be zero; as a result, training is not needed.")
+ } else {
+ logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
+ s"zeros and the intercept will be the mean of the label; as a result, " +
+ s"training is not needed.")
+ }
+ if (handlePersistence) instances.unpersist()
+ val coefficients = Vectors.sparse(numFeatures, Seq())
+ val intercept = yMean
+
+ val model = new LinearRegressionModel(uid, coefficients, intercept)
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
+
+ val trainingSummary = new LinearRegressionTrainingSummary(
+ summaryModel.transform(dataset),
+ predictionColName,
+ $(labelCol),
+ model,
+ Array(0D),
+ $(featuresCol),
+ Array(0D))
+ return copyValues(model.setSummary(trainingSummary))
+ } else {
+ require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
+ "Model cannot be regularized.")
+ logWarning(s"The standard deviation of the label is zero. " +
+ "Consider setting fitIntercept=true.")
+ }
}
+ // If y is constant (rawYStd is zero), then y cannot be scaled. In this case
+ // setting yStd=abs(yMean) ensures that y is not scaled anymore in l-bfgs algorithm.
+ val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean)
val featuresMean = featuresSummarizer.mean.toArray
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
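
A brief sketch of why the yStd substitution above keeps the problem well-posed, assuming the
l-bfgs path divides the label by yStd during standardization (the exact scaling code is outside
this hunk). In this branch fitIntercept is false and the label is a non-zero constant:

    y_i \equiv \bar{y} \neq 0, \qquad \sigma_y = 0, \qquad \text{yStd} := |\bar{y}|
    \tilde{y}_i = \frac{y_i}{\text{yStd}} = \frac{\bar{y}}{|\bar{y}|} = \operatorname{sign}(\bar{y})

so the scaled labels are a finite constant rather than a division by zero, and L-BFGS can still
optimize the (unregularized) objective.
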
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index e71dd9eee03e3..76021ad8f4e65 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -71,12 +71,10 @@ private[spark] object SchemaUtils {
def appendColumn(
schema: StructType,
colName: String,
- dataType: DataType): StructType = {
+ dataType: DataType,
+ nullable: Boolean = false): StructType = {
if (colName.isEmpty) return schema
- val fieldNames = schema.fieldNames
- require(!fieldNames.contains(colName), s"Column $colName already exists.")
- val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
- StructType(outputFields)
+ appendColumn(schema, StructField(colName, dataType, nullable))
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index 7b203e2f40815..88dbfe3fcc9f5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -45,10 +45,10 @@ import org.apache.spark.util.Utils
* This is due to high-dimensional data (a) making it difficult to cluster at all (based
* on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions.
*
- * @param k The number of independent Gaussians in the mixture model
- * @param convergenceTol The maximum change in log-likelihood at which convergence
- * is considered to have occurred.
- * @param maxIterations The maximum number of iterations to perform
+ * @param k Number of independent Gaussians in the mixture model.
+ * @param convergenceTol Maximum change in log-likelihood at which convergence
+ * is considered to have occurred.
+ * @param maxIterations Maximum number of iterations allowed.
*/
@Since("1.3.0")
class GaussianMixture private (
@@ -108,7 +108,7 @@ class GaussianMixture private (
def getK: Int = k
/**
- * Set the maximum number of iterations to run. Default: 100
+ * Set the maximum number of iterations allowed. Default: 100
*/
@Since("1.3.0")
def setMaxIterations(maxIterations: Int): this.type = {
@@ -117,7 +117,7 @@ class GaussianMixture private (
}
/**
- * Return the maximum number of iterations to run
+ * Return the maximum number of iterations allowed
*/
@Since("1.3.0")
def getMaxIterations: Int = maxIterations
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index ca11ede4ccd47..901164a391170 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -70,13 +70,13 @@ class KMeans private (
}
/**
- * Maximum number of iterations to run.
+ * Maximum number of iterations allowed.
*/
@Since("1.4.0")
def getMaxIterations: Int = maxIterations
/**
- * Set maximum number of iterations to run. Default: 20.
+ * Set maximum number of iterations allowed. Default: 20.
*/
@Since("0.8.0")
def setMaxIterations(maxIterations: Int): this.type = {
@@ -482,12 +482,15 @@ object KMeans {
/**
* Trains a k-means model using the given set of parameters.
*
- * @param data training points stored as `RDD[Vector]`
- * @param k number of clusters
- * @param maxIterations max number of iterations
- * @param runs number of parallel runs, defaults to 1. The best model is returned.
- * @param initializationMode initialization model, either "random" or "k-means||" (default).
- * @param seed random seed value for cluster initialization
+ * @param data Training points as an `RDD` of `Vector` types.
+ * @param k Number of clusters to create.
+ * @param maxIterations Maximum number of iterations allowed.
+ * @param runs Number of runs to execute in parallel. The best model according to the cost
+ * function will be returned. (default: 1)
+ * @param initializationMode The initialization algorithm. This can either be "random" or
+ * "k-means||". (default: "k-means||")
+ * @param seed Random seed for cluster initialization. Default is to generate seed based
+ * on system time.
*/
@Since("1.3.0")
def train(
@@ -508,11 +511,13 @@ object KMeans {
/**
* Trains a k-means model using the given set of parameters.
*
- * @param data training points stored as `RDD[Vector]`
- * @param k number of clusters
- * @param maxIterations max number of iterations
- * @param runs number of parallel runs, defaults to 1. The best model is returned.
- * @param initializationMode initialization model, either "random" or "k-means||" (default).
+ * @param data Training points as an `RDD` of `Vector` types.
+ * @param k Number of clusters to create.
+ * @param maxIterations Maximum number of iterations allowed.
+ * @param runs Number of runs to execute in parallel. The best model according to the cost
+ * function will be returned. (default: 1)
+ * @param initializationMode The initialization algorithm. This can either be "random" or
+ * "k-means||". (default: "k-means||")
*/
@Since("0.8.0")
def train(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index eb802a365ed6e..81566b4779d66 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -61,14 +61,13 @@ class LDA private (
ldaOptimizer = new EMLDAOptimizer)
/**
- * Number of topics to infer. I.e., the number of soft cluster centers.
- *
+ * Number of topics to infer, i.e., the number of soft cluster centers.
*/
@Since("1.3.0")
def getK: Int = k
/**
- * Number of topics to infer. I.e., the number of soft cluster centers.
+ * Set the number of topics to infer, i.e., the number of soft cluster centers.
* (default = 10)
*/
@Since("1.3.0")
@@ -222,13 +221,13 @@ class LDA private (
def setBeta(beta: Double): this.type = setTopicConcentration(beta)
/**
- * Maximum number of iterations for learning.
+ * Maximum number of iterations allowed.
*/
@Since("1.3.0")
def getMaxIterations: Int = maxIterations
/**
- * Maximum number of iterations for learning.
+ * Set the maximum number of iterations allowed.
* (default = 20)
*/
@Since("1.3.0")
@@ -238,13 +237,13 @@ class LDA private (
}
/**
- * Random seed
+ * Random seed for cluster initialization.
*/
@Since("1.3.0")
def getSeed: Long = seed
/**
- * Random seed
+ * Set the random seed for cluster initialization.
*/
@Since("1.3.0")
def setSeed(seed: Long): this.type = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 2ab0920b06363..1ab7cb393b081 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -111,7 +111,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
*
* @param k Number of clusters.
* @param maxIterations Maximum number of iterations of the PIC algorithm.
- * @param initMode Initialization mode.
+ * @param initMode The initialization mode. This can be either "random" to use a random vector
+ * as vertex properties, or "degree" to use normalized sum similarities.
+ * Default: "random".
*
* @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]]
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 79d217e183c62..d99b89dc49ebf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -183,7 +183,7 @@ class StreamingKMeans @Since("1.2.0") (
}
/**
- * Set the decay factor directly (for forgetful algorithms).
+ * Set the forgetfulness of the previous centroids.
*/
@Since("1.2.0")
def setDecayFactor(a: Double): this.type = {
@@ -192,7 +192,9 @@ class StreamingKMeans @Since("1.2.0") (
}
/**
- * Set the half life and time unit ("batches" or "points") for forgetful algorithms.
+ * Set the half life and time unit ("batches" or "points"). If "points", the decay factor
+ * is raised to the power of the number of new points; if "batches", the decay factor is
+ * used as is.
*/
@Since("1.2.0")
def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
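
A sketch of what the two setters above imply, in the standard forgetful running-average form;
the exact update lives elsewhere in StreamingKMeans, so the equations below are an assumption
consistent with the doc text rather than a quote of the code:

    a = \exp\!\left(\frac{\ln 0.5}{\text{halfLife}}\right) = 0.5^{1/\text{halfLife}}
    c_{t+1} = \frac{c_t\, n_t\, a + x_t\, m_t}{n_t\, a + m_t}, \qquad n_{t+1} = n_t\, a + m_t

Here c_t is a centroid with weight n_t, and x_t, m_t are the mean and count of the new points
assigned to it; with timeUnit = "points" the factor a is raised to m_t instead of being applied
once per batch.
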
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index fb217e0c1de93..a5b24c18565b9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -89,4 +89,19 @@ class StopWordsRemoverSuite
.setCaseSensitive(true)
testDefaultReadWrite(t)
}
+
+ test("StopWordsRemover output column already exists") {
+ val outputCol = "expected"
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol(outputCol)
+ val dataSet = sqlContext.createDataFrame(Seq(
+ (Seq("The", "the", "swift"), Seq("swift"))
+ )).toDF("raw", outputCol)
+
+ val thrown = intercept[IllegalArgumentException] {
+ testStopWordsRemover(remover, dataSet)
+ }
+ assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.")
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala
new file mode 100644
index 0000000000000..604021220a139
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+
+class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ private var instances1: RDD[Instance] = _
+ private var instances2: RDD[Instance] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2)
+ b <- c(1, 0, 1, 0)
+ w <- c(1, 2, 3, 4)
+ */
+ instances1 = sc.parallelize(Seq(
+ Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)),
+ Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)),
+ Instance(0.0, 4.0, Vectors.dense(3.0, 3.0))
+ ), 2)
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b <- c(2, 8, 3, 9)
+ w <- c(1, 2, 3, 4)
+ */
+ instances2 = sc.parallelize(Seq(
+ Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(9.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2)
+ }
+
+ test("IRLS against GLM with Binomial errors") {
+ /*
+ R code:
+
+ df <- as.data.frame(cbind(A, b))
+ for (formula in c(b ~ . -1, b ~ .)) {
+ model <- glm(formula, family="binomial", data=df, weights=w)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -0.30216651 -0.04452045
+ [1] 3.5651651 -1.2334085 -0.7348971
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, -0.30216651, -0.04452045),
+ Vectors.dense(3.5651651, -1.2334085, -0.7348971))
+
+ import IterativelyReweightedLeastSquaresSuite._
+
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ val newInstances = instances1.map { instance =>
+ val mu = (instance.label + 0.5) / 2.0
+ val eta = math.log(mu / (1.0 - mu))
+ Instance(eta, instance.weight, instance.features)
+ }
+ val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0,
+ standardizeFeatures = false, standardizeLabel = false).fit(newInstances)
+ val irls = new IterativelyReweightedLeastSquares(initial, BinomialReweightFunc,
+ fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances1)
+ val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1))
+ assert(actual ~== expected(idx) absTol 1e-4)
+ idx += 1
+ }
+ }
+
+ test("IRLS against GLM with Poisson errors") {
+ /*
+ R code:
+
+ df <- as.data.frame(cbind(A, b))
+ for (formula in c(b ~ . -1, b ~ .)) {
+ model <- glm(formula, family="poisson", data=df, weights=w)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -0.09607792 0.18375613
+ [1] 6.299947 3.324107 -1.081766
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, -0.09607792, 0.18375613),
+ Vectors.dense(6.299947, 3.324107, -1.081766))
+
+ import IterativelyReweightedLeastSquaresSuite._
+
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ val yMean = instances2.map(_.label).mean
+ val newInstances = instances2.map { instance =>
+ val mu = (instance.label + yMean) / 2.0
+ val eta = math.log(mu)
+ Instance(eta, instance.weight, instance.features)
+ }
+ val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0,
+ standardizeFeatures = false, standardizeLabel = false).fit(newInstances)
+ val irls = new IterativelyReweightedLeastSquares(initial, PoissonReweightFunc,
+ fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances2)
+ val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1))
+ assert(actual ~== expected(idx) absTol 1e-4)
+ idx += 1
+ }
+ }
+
+ test("IRLS against L1Regression") {
+ /*
+ R code:
+
+ library(quantreg)
+
+ df <- as.data.frame(cbind(A, b))
+ for (formula in c(b ~ . -1, b ~ .)) {
+ model <- rq(formula, data=df, weights=w)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 1.266667 0.400000
+ [1] 29.5 17.0 -5.5
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, 1.266667, 0.400000),
+ Vectors.dense(29.5, 17.0, -5.5))
+
+ import IterativelyReweightedLeastSquaresSuite._
+
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0,
+ standardizeFeatures = false, standardizeLabel = false).fit(instances2)
+ val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc,
+ fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2)
+ val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1))
+ assert(actual ~== expected(idx) absTol 1e-4)
+ idx += 1
+ }
+ }
+}
+
+object IterativelyReweightedLeastSquaresSuite {
+
+ def BinomialReweightFunc(
+ instance: Instance,
+ model: WeightedLeastSquaresModel): (Double, Double) = {
+ val eta = model.predict(instance.features)
+ val mu = 1.0 / (1.0 + math.exp(-1.0 * eta))
+ val z = eta + (instance.label - mu) / (mu * (1.0 - mu))
+ val w = mu * (1 - mu) * instance.weight
+ (z, w)
+ }
+
+ def PoissonReweightFunc(
+ instance: Instance,
+ model: WeightedLeastSquaresModel): (Double, Double) = {
+ val eta = model.predict(instance.features)
+ val mu = math.exp(eta)
+ val z = eta + (instance.label - mu) / mu
+ val w = mu * instance.weight
+ (z, w)
+ }
+
+ def L1RegressionReweightFunc(
+ instance: Instance,
+ model: WeightedLeastSquaresModel): (Double, Double) = {
+ val eta = model.predict(instance.features)
+ val e = math.max(math.abs(eta - instance.label), 1e-7)
+ val w = 1 / e
+ val y = instance.label
+ (y, w)
+ }
+}
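
For reference, the three helpers above are instances of the generic IRLS working response,
restated in math form (same content as the code):

    \text{Binomial (logit link):}\quad \mu = \frac{1}{1 + e^{-\eta}}, \quad z = \eta + \frac{y - \mu}{\mu(1-\mu)}, \quad w = \mu(1-\mu)\, w_{\text{inst}}
    \text{Poisson (log link):}\quad \mu = e^{\eta}, \quad z = \eta + \frac{y - \mu}{\mu}, \quad w = \mu\, w_{\text{inst}}
    \text{L1 regression (IRLS for least absolute deviations):}\quad z = y, \quad w = \frac{1}{\max(|\eta - y|,\, 10^{-7})}
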
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 273c882c2a47f..81fc6603ccfe6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -37,6 +37,8 @@ class LinearRegressionSuite
@transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _
@transient var datasetWithSparseFeature: DataFrame = _
@transient var datasetWithWeight: DataFrame = _
+ @transient var datasetWithWeightConstantLabel: DataFrame = _
+ @transient var datasetWithWeightZeroLabel: DataFrame = _
/*
In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -92,6 +94,29 @@ class LinearRegressionSuite
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))
+
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b.const <- c(17, 17, 17, 17)
+ w <- c(1, 2, 3, 4)
+ df.const.label <- as.data.frame(cbind(A, b.const))
+ */
+ datasetWithWeightConstantLabel = sqlContext.createDataFrame(
+ sc.parallelize(Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
+ datasetWithWeightZeroLabel = sqlContext.createDataFrame(
+ sc.parallelize(Seq(
+ Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(0.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
}
test("params") {
@@ -558,6 +583,86 @@ class LinearRegressionSuite
}
}
+ test("linear regression model with constant label") {
+ /*
+ R code:
+ for (formula in c(b.const ~ . -1, b.const ~ .)) {
+ model <- lm(formula, data=df.const.label, weights=w)
+ print(as.vector(coef(model)))
+ }
+ [1] -9.221298 3.394343
+ [1] 17 0 0
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, -9.221298, 3.394343),
+ Vectors.dense(17.0, 0.0, 0.0))
+
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ var idx = 0
+ for (fitIntercept <- Seq(false, true)) {
+ val model1 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver(solver)
+ .fit(datasetWithWeightConstantLabel)
+ val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
+ model1.coefficients(1))
+ assert(actual1 ~== expected(idx) absTol 1e-4)
+
+ val model2 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver(solver)
+ .fit(datasetWithWeightZeroLabel)
+ val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
+ model2.coefficients(1))
+ assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
+ idx += 1
+ }
+ }
+ }
+
+ test("regularized linear regression through origin with constant label") {
+ // The problem is ill-defined if fitIntercept=false and regParam is non-zero.
+ // An exception is thrown in this case.
+ Seq("auto", "l-bfgs", "normal").foreach { solver =>
+ for (standardization <- Seq(false, true)) {
+ val model = new LinearRegression().setFitIntercept(false)
+ .setRegParam(0.1).setStandardization(standardization).setSolver(solver)
+ intercept[IllegalArgumentException] {
+ model.fit(datasetWithWeightConstantLabel)
+ }
+ }
+ }
+ }
+
+ test("linear regression with l-bfgs when training is not needed") {
+ // When label is constant, l-bfgs solver returns results without training.
+ // There are two possibilities: if the label is non-zero but constant,
+ // and fitIntercept is true, then the model returns yMean as the intercept without training.
+ // If the label is all zeros, then all coefficients are zero regardless of fitIntercept, so
+ // no training is needed.
+ for (fitIntercept <- Seq(false, true)) {
+ for (standardization <- Seq(false, true)) {
+ val model1 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setStandardization(standardization)
+ .setWeightCol("weight")
+ .setSolver("l-bfgs")
+ .fit(datasetWithWeightConstantLabel)
+ if (fitIntercept) {
+ assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+ }
+ val model2 = new LinearRegression()
+ .setFitIntercept(fitIntercept)
+ .setWeightCol("weight")
+ .setSolver("l-bfgs")
+ .fit(datasetWithWeightZeroLabel)
+ assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
+ }
+ }
+ }
+
test("linear regression model training summary") {
Seq("auto", "l-bfgs", "normal").foreach { solver =>
val trainer = new LinearRegression().setSolver(solver)
diff --git a/network/common/pom.xml b/network/common/pom.xml
index eda2b7307088f..bd507c2cb6c4b 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -21,13 +21,13 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
     <version>2.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>
 
   <groupId>org.apache.spark</groupId>
-  <artifactId>spark-network-common_2.10</artifactId>
+  <artifactId>spark-network-common_2.11</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Networking</name>
   <url>http://spark.apache.org/</url>
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
index f9aa7e2dd1f43..810ec10ca05b3 100644
--- a/network/shuffle/pom.xml
+++ b/network/shuffle/pom.xml
@@ -21,13 +21,13 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
     <version>2.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>
 
   <groupId>org.apache.spark</groupId>
-  <artifactId>spark-network-shuffle_2.10</artifactId>
+  <artifactId>spark-network-shuffle_2.11</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Shuffle Streaming Service</name>
   <url>http://spark.apache.org/</url>
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index a19cbb04b18c6..a28785b16e1e6 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -21,13 +21,13 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
     <version>2.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>
 
   <groupId>org.apache.spark</groupId>
-  <artifactId>spark-network-yarn_2.10</artifactId>
+  <artifactId>spark-network-yarn_2.11</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project YARN Shuffle Service</name>
   <url>http://spark.apache.org/</url>
diff --git a/pom.xml b/pom.xml
index fb7750602c425..d0387aca66d0d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -25,7 +25,7 @@
     <version>14</version>
   </parent>
   <groupId>org.apache.spark</groupId>
-  <artifactId>spark-parent_2.10</artifactId>
+  <artifactId>spark-parent_2.11</artifactId>
   <version>2.0.0-SNAPSHOT</version>
   <packaging>pom</packaging>
   <name>Spark Project Parent POM</name>
@@ -165,7 +165,7 @@
3.2.22.10.5
- 2.10
+ 2.11${scala.version}org.scala-lang1.9.13
@@ -2456,7 +2456,7 @@
     <profile>
       <id>scala-2.10</id>
       <activation>
-        <property><name>!scala-2.11</name></property>
+        <property><name>scala-2.10</name></property>
       </activation>
       <properties>
         <scala.version>2.10.5</scala.version>
@@ -2488,7 +2488,7 @@
     <profile>
       <id>scala-2.11</id>
       <activation>
-        <property><name>scala-2.11</name></property>
+        <property><name>!scala-2.10</name></property>
       </activation>
       <properties>
         <scala.version>2.11.7</scala.version>
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index 41856443af49b..4adf64a5a0d86 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -95,7 +95,7 @@ object MimaBuild {
// because spark-streaming-mqtt(1.6.0) depends on it.
// Remove the setting on updating previousSparkVersion.
val previousSparkVersion = "1.6.0"
- val fullId = "spark-" + projectRef.project + "_2.10"
+ val fullId = "spark-" + projectRef.project + "_2.11"
mimaDefaultSettings ++
Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value),
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0b5a2e4ede6e0..9209094385395 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -45,6 +45,10 @@ object MimaExcludes {
excludePackage("org.apache.spark.sql.execution"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.status.api.v1.ApplicationAttemptInfo.$default$5"),
// SPARK-12600 Remove SQL deprecated methods
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"),
@@ -197,6 +201,11 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=")
+ ) ++ Seq(
+ // SPARK-12689 Migrate DDL parsing to the newly absorbed parser
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser")
) ++ Seq(
// SPARK-7799 Add "streaming-akka" project
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"),
@@ -217,6 +226,12 @@ object MimaExcludes {
// SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation")
+ ) ++ Seq(
+ // SPARK-6363 Make Scala 2.11 the default Scala version
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint")
)
case v if v.startsWith("1.6") =>
Seq(
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 4224a65a822b8..550b5bad8a46a 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -119,11 +119,11 @@ object SparkBuild extends PomBuild {
v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
}
- if (System.getProperty("scala-2.11") == "") {
- // To activate scala-2.11 profile, replace empty property value to non-empty value
+ if (System.getProperty("scala-2.10") == "") {
+ // To activate scala-2.10 profile, replace empty property value to non-empty value
// in the same way as Maven which handles -Dname as -Dname=true before executes build process.
// see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082
- System.setProperty("scala-2.11", "true")
+ System.setProperty("scala-2.10", "true")
}
profiles
}
@@ -382,7 +382,7 @@ object OldDeps {
lazy val project = Project("oldDeps", file("dev"), settings = oldDepsSettings)
def versionArtifact(id: String): Option[sbt.ModuleID] = {
- val fullId = id + "_2.10"
+ val fullId = id + "_2.11"
Some("org.apache.spark" % fullId % "1.2.0")
}
@@ -390,7 +390,7 @@ object OldDeps {
name := "old-deps",
scalaVersion := "2.10.5",
libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq",
- "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter",
+ "spark-streaming-flume", "spark-streaming-twitter",
"spark-streaming", "spark-mllib", "spark-graphx",
"spark-core").map(versionArtifact(_).get intransitive())
)
@@ -704,7 +704,7 @@ object Java8TestSettings {
lazy val settings = Seq(
javacJVMVersion := "1.8",
// Targeting Java 8 bytecode is only supported in Scala 2.11.4 and higher:
- scalacJVMVersion := (if (System.getProperty("scala-2.11") == "true") "1.8" else "1.7")
+ scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8")
)
}
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 3da36d32c5af0..ea86d6aeb8b31 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -314,3 +314,27 @@ def _copyValues(self, to, extra=None):
if p in paramMap and to.hasParam(p.name):
to._set(**{p.name: paramMap[p]})
return to
+
+ def _resetUid(self, newUid):
+ """
+ Changes the uid of this instance. This updates both
+ the stored uid and the parent uid of params and param maps.
+ This is used by persistence (loading).
+ :param newUid: new uid to use
+ :return: same instance, but with the uid and Param.parent values
+ updated, including within param maps
+ """
+ self.uid = newUid
+ newDefaultParamMap = dict()
+ newParamMap = dict()
+ for param in self.params:
+ newParam = copy.copy(param)
+ newParam.parent = newUid
+ if param in self._defaultParamMap:
+ newDefaultParamMap[newParam] = self._defaultParamMap[param]
+ if param in self._paramMap:
+ newParamMap[newParam] = self._paramMap[param]
+ param.parent = newUid
+ self._defaultParamMap = newDefaultParamMap
+ self._paramMap = newParamMap
+ return self
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 74a2248ed07c8..20dc6c2db91f3 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -18,9 +18,9 @@
import warnings
from pyspark import since
-from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
+from pyspark.ml.util import *
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.mllib.common import inherit_doc
@@ -35,7 +35,7 @@
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization, HasSolver, HasWeightCol):
+ HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable):
"""
Linear regression.
@@ -68,6 +68,25 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
Traceback (most recent call last):
...
TypeError: Method setParams forces keyword arguments.
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lr_path = path + "/lr"
+ >>> lr.save(lr_path)
+ >>> lr2 = LinearRegression.load(lr_path)
+ >>> lr2.getMaxIter()
+ 5
+ >>> model_path = path + "/lr_model"
+ >>> model.save(model_path)
+ >>> model2 = LinearRegressionModel.load(model_path)
+ >>> model.coefficients[0] == model2.coefficients[0]
+ True
+ >>> model.intercept == model2.intercept
+ True
+ >>> from shutil import rmtree
+ >>> try:
+ ... rmtree(path)
+ ... except OSError:
+ ... pass
.. versionadded:: 1.4.0
"""
@@ -106,7 +125,7 @@ def _create_model(self, java_model):
return LinearRegressionModel(java_model)
-class LinearRegressionModel(JavaModel):
+class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by LinearRegression.
@@ -821,9 +840,10 @@ def predict(self, features):
if __name__ == "__main__":
import doctest
+ import pyspark.ml.regression
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
- globs = globals().copy()
+ globs = pyspark.ml.regression.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.regression tests")
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c45a159c460f3..54806ee336666 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -34,18 +34,22 @@
else:
import unittest
-from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext, Row
-from pyspark.sql.functions import rand
+from shutil import rmtree
+import tempfile
+
+from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
-from pyspark.ml.util import keyword_only
-from pyspark.ml import Estimator, Model, Pipeline, Transformer
-from pyspark.ml.feature import *
+from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
+from pyspark.ml.util import keyword_only
from pyspark.mllib.linalg import DenseVector
+from pyspark.sql import DataFrame, SQLContext, Row
+from pyspark.sql.functions import rand
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
class MockDataset(DataFrame):
@@ -405,6 +409,26 @@ def test_fit_maximize_metric(self):
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+class PersistenceTest(PySparkTestCase):
+
+ def test_linear_regression(self):
+ lr = LinearRegression(maxIter=1)
+ path = tempfile.mkdtemp()
+ lr_path = path + "/lr"
+ lr.save(lr_path)
+ lr2 = LinearRegression.load(lr_path)
+ self.assertEqual(lr2.uid, lr2.maxIter.parent,
+ "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
+ % (lr2.uid, lr2.maxIter.parent))
+ self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+ "Loaded LinearRegression instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index cee9d67b05325..d7a813f56cd57 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -15,8 +15,27 @@
# limitations under the License.
#
-from functools import wraps
+import sys
import uuid
+from functools import wraps
+
+if sys.version > '3':
+ basestring = str
+
+from pyspark import SparkContext, since
+from pyspark.mllib.common import inherit_doc
+
+
+def _jvm():
+ """
+ Returns the JVM view associated with SparkContext. Must be called
+ after SparkContext is initialized.
+ """
+ jvm = SparkContext._jvm
+ if jvm:
+ return jvm
+ else:
+ raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
def keyword_only(func):
@@ -52,3 +71,124 @@ def _randomUID(cls):
concatenates the class name, "_", and 12 random hex chars.
"""
return cls.__name__ + "_" + uuid.uuid4().hex[12:]
+
+
+@inherit_doc
+class JavaMLWriter(object):
+ """
+ .. note:: Experimental
+
+ Utility class that can save ML instances through their Scala implementation.
+
+ .. versionadded:: 2.0.0
+ """
+
+ def __init__(self, instance):
+ instance._transfer_params_to_java()
+ self._jwrite = instance._java_obj.write()
+
+ def save(self, path):
+ """Save the ML instance to the input path."""
+ if not isinstance(path, basestring):
+ raise TypeError("path should be a basestring, got type %s" % type(path))
+ self._jwrite.save(path)
+
+ def overwrite(self):
+ """Overwrites if the output path already exists."""
+ self._jwrite.overwrite()
+ return self
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for saving."""
+ self._jwrite.context(sqlContext._ssql_ctx)
+ return self
+
+
+@inherit_doc
+class MLWritable(object):
+ """
+ .. note:: Experimental
+
+ Mixin for ML instances that provide JavaMLWriter.
+
+ .. versionadded:: 2.0.0
+ """
+
+ def write(self):
+ """Returns a JavaMLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+
+@inherit_doc
+class JavaMLReader(object):
+ """
+ .. note:: Experimental
+
+ Utility class that can load ML instances through their Scala implementation.
+
+ .. versionadded:: 2.0.0
+ """
+
+ def __init__(self, clazz):
+ self._clazz = clazz
+ self._jread = self._load_java_obj(clazz).read()
+
+ def load(self, path):
+ """Load the ML instance from the input path."""
+ if not isinstance(path, basestring):
+ raise TypeError("path should be a basestring, got type %s" % type(path))
+ java_obj = self._jread.load(path)
+ instance = self._clazz()
+ instance._java_obj = java_obj
+ instance._resetUid(java_obj.uid())
+ instance._transfer_params_from_java()
+ return instance
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for loading."""
+ self._jread.context(sqlContext._ssql_ctx)
+ return self
+
+ @classmethod
+ def _java_loader_class(cls, clazz):
+ """
+ Returns the full class name of the Java ML instance. The default
+ implementation replaces "pyspark" by "org.apache.spark" in
+ the Python full class name.
+ """
+ java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
+ return ".".join([java_package, clazz.__name__])
+
+ @classmethod
+ def _load_java_obj(cls, clazz):
+ """Load the peer Java object of the ML instance."""
+ java_class = cls._java_loader_class(clazz)
+ java_obj = _jvm()
+ for name in java_class.split("."):
+ java_obj = getattr(java_obj, name)
+ return java_obj
+
+
+@inherit_doc
+class MLReadable(object):
+ """
+ .. note:: Experimental
+
+ Mixin for instances that provide JavaMLReader.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @classmethod
+ def read(cls):
+ """Returns a JavaMLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def load(cls, path):
+ """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
+ return cls.read().load(path)
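
A minimal usage sketch of the two mixins from Python, mirroring the LinearRegression doctest added
earlier in this patch; `training_df` is a placeholder DataFrame with "features" and "label" columns
and is not part of the patch:

    import tempfile

    from pyspark.ml.regression import LinearRegression, LinearRegressionModel

    path = tempfile.mkdtemp()

    # Estimator round trip: save() comes from MLWritable, load() from MLReadable.
    lr = LinearRegression(maxIter=5)
    lr.save(path + "/lr")
    lr2 = LinearRegression.load(path + "/lr")
    assert lr2.getMaxIter() == 5

    # Fitted models mix in the same pair, so they round-trip the same way.
    model = lr.fit(training_df)  # training_df: assumed DataFrame with "features"/"label"
    model.save(path + "/lr_model")
    model2 = LinearRegressionModel.load(path + "/lr_model")
    assert model.intercept == model2.intercept
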
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index dd1d4b076eddd..d4d48eb2150e3 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -21,21 +21,10 @@
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer, Model
+from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
-def _jvm():
- """
- Returns the JVM view associated with SparkContext. Must be called
- after SparkContext is initialized.
- """
- jvm = SparkContext._jvm
- if jvm:
- return jvm
- else:
- raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
-
-
@inherit_doc
class JavaWrapper(Params):
"""
@@ -159,15 +148,24 @@ class JavaModel(Model, JavaTransformer):
__metaclass__ = ABCMeta
- def __init__(self, java_model):
+ def __init__(self, java_model=None):
"""
Initialize this instance with a Java model object.
Subclasses should call this constructor, initialize params,
and then call _transfer_params_from_java.
+
+ This instance can be instantiated without specifying java_model;
+ it will be assigned later, but this scenario is only used by
+ :py:class:`JavaMLReader` to load models. This is a bit of a
+ hack, but it is easiest since a proper fix would require
+ MLReader (in pyspark.ml.util) to depend on these wrappers, but
+ these wrappers depend on pyspark.ml.util (both directly and via
+ other ML classes).
"""
super(JavaModel, self).__init__()
- self._java_obj = java_model
- self.uid = java_model.uid()
+ if java_model is not None:
+ self._java_obj = java_model
+ self.uid = java_model.uid()
def copy(self, extra=None):
"""
@@ -182,8 +180,9 @@ def copy(self, extra=None):
if extra is None:
extra = dict()
that = super(JavaModel, self).copy(extra)
- that._java_obj = self._java_obj.copy(self._empty_java_param_map())
- that._transfer_params_to_java()
+ if self._java_obj is not None:
+ that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+ that._transfer_params_to_java()
return that
def _call_java(self, name, *args):
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 4e9eb96fd9da1..ad04e46e8870b 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -88,8 +88,11 @@ def predict(self, x):
Find the cluster that each of the points belongs to in this
model.
- :param x: the point (or RDD of points) to determine
- compute the clusters for.
+ :param x:
+ A data point (or RDD of points) to determine cluster index.
+ :return:
+ Predicted cluster index or an RDD of predicted cluster indices
+ if the input is an RDD.
"""
if isinstance(x, RDD):
vecs = x.map(_convert_to_vector)
@@ -105,7 +108,8 @@ def computeCost(self, x):
points to their nearest center) for this model on the given
data. If provided with an RDD of points returns the sum.
- :param point: the point or RDD of points to compute the cost(s).
+ :param point:
+ A data point (or RDD of points) to compute the cost(s).
"""
if isinstance(x, RDD):
vecs = x.map(_convert_to_vector)
@@ -143,17 +147,23 @@ def train(self, rdd, k=4, maxIterations=20, minDivisibleClusterSize=1.0, seed=-1
"""
Runs the bisecting k-means algorithm return the model.
- :param rdd: input RDD to be trained on
- :param k: The desired number of leaf clusters (default: 4).
- The actual number could be smaller if there are no divisible
- leaf clusters.
- :param maxIterations: the max number of k-means iterations to
- split clusters (default: 20)
- :param minDivisibleClusterSize: the minimum number of points
- (if >= 1.0) or the minimum proportion of points (if < 1.0)
- of a divisible cluster (default: 1)
- :param seed: a random seed (default: -1888008604 from
- classOf[BisectingKMeans].getName.##)
+ :param rdd:
+ Training points as an `RDD` of `Vector` or convertible
+ sequence types.
+ :param k:
+ The desired number of leaf clusters. The actual number could
+ be smaller if there are no divisible leaf clusters.
+ (default: 4)
+ :param maxIterations:
+ Maximum number of iterations allowed to split clusters.
+ (default: 20)
+ :param minDivisibleClusterSize:
+ Minimum number of points (if >= 1.0) or the minimum proportion
+ of points (if < 1.0) of a divisible cluster.
+ (default: 1)
+ :param seed:
+ Random seed value for cluster initialization.
+ (default: -1888008604 from classOf[BisectingKMeans].getName.##)
"""
java_model = callMLlibFunc(
"trainBisectingKMeans", rdd.map(_convert_to_vector),
@@ -239,8 +249,11 @@ def predict(self, x):
Find the cluster that each of the points belongs to in this
model.
- :param x: the point (or RDD of points) to determine
- compute the clusters for.
+ :param x:
+ A data point (or RDD of points) to determine cluster index.
+ :return:
+ Predicted cluster index or an RDD of predicted cluster indices
+ if the input is an RDD.
"""
best = 0
best_distance = float("inf")
@@ -262,7 +275,8 @@ def computeCost(self, rdd):
their nearest center) for this model on the given
data.
- :param point: the RDD of points to compute the cost on.
+ :param rdd:
+ The RDD of points to compute the cost on.
"""
cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector),
[_convert_to_vector(c) for c in self.centers])
@@ -296,7 +310,44 @@ class KMeans(object):
@since('0.9.0')
def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||",
seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None):
- """Train a k-means clustering model."""
+ """
+ Train a k-means clustering model.
+
+ :param rdd:
+ Training points as an `RDD` of `Vector` or convertible
+ sequence types.
+ :param k:
+ Number of clusters to create.
+ :param maxIterations:
+ Maximum number of iterations allowed.
+ (default: 100)
+ :param runs:
+ Number of runs to execute in parallel. The best model according
+ to the cost function will be returned (deprecated in 1.6.0).
+ (default: 1)
+ :param initializationMode:
+ The initialization algorithm. This can be either "random" or
+ "k-means||".
+ (default: "k-means||")
+ :param seed:
+ Random seed value for cluster initialization. Set as None to
+ generate seed based on system time.
+ (default: None)
+ :param initializationSteps:
+ Number of steps for the k-means|| initialization mode.
+ This is an advanced setting -- the default of 5 is almost
+ always enough.
+ (default: 5)
+ :param epsilon:
+ Distance threshold within which a center will be considered to
+ have converged. If all centers move less than this Euclidean
+ distance, iterations are stopped.
+ (default: 1e-4)
+ :param initialModel:
+ Initial cluster centers can be provided as a KMeansModel object
+ rather than using the random or k-means|| initializationMode.
+ (default: None)
+ """
if runs != 1:
warnings.warn(
"Support for runs is deprecated in 1.6.0. This param will have no effect in 2.0.0.")
@@ -415,8 +466,11 @@ def predict(self, x):
Find the cluster to which the point 'x' or each point in RDD 'x'
has maximum membership in this model.
- :param x: vector or RDD of vector represents data points.
- :return: cluster label or RDD of cluster labels.
+ :param x:
+ A feature vector or an RDD of vectors representing data points.
+ :return:
+ Predicted cluster label or an RDD of predicted cluster labels
+ if the input is an RDD.
"""
if isinstance(x, RDD):
cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
@@ -430,9 +484,11 @@ def predictSoft(self, x):
"""
Find the membership of point 'x' or each point in RDD 'x' to all mixture components.
- :param x: vector or RDD of vector represents data points.
- :return: the membership value to all mixture components for vector 'x'
- or each vector in RDD 'x'.
+ :param x:
+ A feature vector or an RDD of vectors representing data points.
+ :return:
+ The membership value to all mixture components for vector 'x'
+ or each vector in RDD 'x'.
"""
if isinstance(x, RDD):
means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
@@ -447,8 +503,10 @@ def predictSoft(self, x):
def load(cls, sc, path):
"""Load the GaussianMixtureModel from disk.
- :param sc: SparkContext
- :param path: str, path to where the model is stored.
+ :param sc:
+ SparkContext.
+ :param path:
+ Path to where the model is stored.
"""
model = cls._load_java(sc, path)
wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
@@ -461,19 +519,35 @@ class GaussianMixture(object):
Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm.
- :param data: RDD of data points
- :param k: Number of components
- :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3
- :param maxIterations: Number of iterations. Default to 100
- :param seed: Random Seed
- :param initialModel: GaussianMixtureModel for initializing learning
-
.. versionadded:: 1.3.0
"""
@classmethod
@since('1.3.0')
def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None):
- """Train a Gaussian Mixture clustering model."""
+ """
+ Train a Gaussian Mixture clustering model.
+
+ :param rdd:
+ Training points as an `RDD` of `Vector` or convertible
+ sequence types.
+ :param k:
+ Number of independent Gaussians in the mixture model.
+ :param convergenceTol:
+ Maximum change in log-likelihood at which convergence is
+ considered to have occurred.
+ (default: 1e-3)
+ :param maxIterations:
+ Maximum number of iterations allowed.
+ (default: 100)
+ :param seed:
+ Random seed for initial Gaussian distribution. Set as None to
+ generate seed based on system time.
+ (default: None)
+ :param initialModel:
+ Initial GMM starting point, bypassing the random
+ initialization.
+ (default: None)
+ """
initialModelWeights = None
initialModelMu = None
initialModelSigma = None
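
A short usage sketch matching the parameters documented above, assuming an existing SparkContext
`sc`; the data values are illustrative only:

    from pyspark.mllib.clustering import GaussianMixture
    from pyspark.mllib.linalg import Vectors

    # Four points forming two obvious clusters; the remaining arguments keep the
    # documented defaults.
    data = sc.parallelize([
        Vectors.dense([-0.10, -0.05]), Vectors.dense([-0.01, -0.10]),
        Vectors.dense([0.90, 0.80]), Vectors.dense([0.75, 0.935]),
    ])
    model = GaussianMixture.train(data, k=2, convergenceTol=1e-3,
                                  maxIterations=100, seed=10)
    print(model.weights)                                   # mixing weights of the two Gaussians
    print(model.predict(Vectors.dense([-0.05, -0.08])))    # index of the most likely component
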
@@ -574,18 +648,24 @@ class PowerIterationClustering(object):
@since('1.5.0')
def train(cls, rdd, k, maxIterations=100, initMode="random"):
"""
- :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the
- affinity matrix, which is the matrix A in the PIC paper.
- The similarity s,,ij,, must be nonnegative.
- This is a symmetric matrix and hence s,,ij,, = s,,ji,,.
- For any (i, j) with nonzero similarity, there should be
- either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input.
- Tuples with i = j are ignored, because we assume
- s,,ij,, = 0.0.
- :param k: Number of clusters.
- :param maxIterations: Maximum number of iterations of the
- PIC algorithm.
- :param initMode: Initialization mode.
+ :param rdd:
+ An RDD of (i, j, s\ :sub:`ij`\) tuples representing the
+ affinity matrix, which is the matrix A in the PIC paper. The
+ similarity s\ :sub:`ij`\ must be nonnegative. This is a symmetric
+ matrix and hence s\ :sub:`ij`\ = s\ :sub:`ji`\. For any (i, j) with
+ nonzero similarity, there should be either (i, j, s\ :sub:`ij`\) or
+ (j, i, s\ :sub:`ji`\) in the input. Tuples with i = j are ignored,
+ because it is assumed s\ :sub:`ij`\ = 0.0.
+ :param k:
+ Number of clusters.
+ :param maxIterations:
+ Maximum number of iterations of the PIC algorithm.
+ (default: 100)
+ :param initMode:
+ Initialization mode. This can be either "random" to use
+ a random vector as vertex properties, or "degree" to use
+ normalized sum similarities.
+ (default: "random")
"""
model = callMLlibFunc("trainPowerIterationClusteringModel",
rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode)
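
A sketch of the affinity-matrix input format described in the docstring above, assuming an
existing SparkContext `sc`; the similarity values are illustrative only:

    from pyspark.mllib.clustering import PowerIterationClustering

    # Each tuple is (i, j, s_ij) with s_ij >= 0; only one of (i, j) / (j, i) is needed.
    similarities = sc.parallelize([
        (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0),
        (1, 2, 1.0), (2, 3, 1.0),
    ])
    model = PowerIterationClustering.train(similarities, k=2, maxIterations=40,
                                           initMode="degree")
    for a in model.assignments().collect():
        print(a.id, a.cluster)
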
@@ -625,8 +705,10 @@ class StreamingKMeansModel(KMeansModel):
and new data. If it is set to zero, the old centroids are completely
forgotten.
- :param clusterCenters: Initial cluster centers.
- :param clusterWeights: List of weights assigned to each cluster.
+ :param clusterCenters:
+ Initial cluster centers.
+ :param clusterWeights:
+ List of weights assigned to each cluster.
>>> initCenters = [[0.0, 0.0], [1.0, 1.0]]
>>> initWeights = [1.0, 1.0]
@@ -673,11 +755,14 @@ def clusterWeights(self):
def update(self, data, decayFactor, timeUnit):
"""Update the centroids, according to data
- :param data: Should be a RDD that represents the new data.
- :param decayFactor: forgetfulness of the previous centroids.
- :param timeUnit: Can be "batches" or "points". If points, then the
- decay factor is raised to the power of number of new
- points and if batches, it is used as it is.
+ :param data:
+ RDD with new data for the model update.
+ :param decayFactor:
+ Forgetfulness of the previous centroids.
+ :param timeUnit:
+ Can be "batches" or "points". If "points", then the decay factor
+ is raised to the power of the number of new points; if "batches",
+ then the decay factor will be used as is.
"""
if not isinstance(data, RDD):
raise TypeError("Data should be of an RDD, got %s." % type(data))
@@ -704,10 +789,17 @@ class StreamingKMeans(object):
More details on how the centroids are updated are provided under the
docs of StreamingKMeansModel.
- :param k: int, number of clusters
- :param decayFactor: float, forgetfulness of the previous centroids.
- :param timeUnit: can be "batches" or "points". If points, then the
- decayfactor is raised to the power of no. of new points.
+ :param k:
+ Number of clusters.
+ (default: 2)
+ :param decayFactor:
+ Forgetfulness of the previous centroids.
+ (default: 1.0)
+ :param timeUnit:
+ Can be "batches" or "points". If "points", then the decay factor is
+ raised to the power of the number of new points; if "batches", then
+ the decay factor will be used as is.
+ (default: "batches")
.. versionadded:: 1.5.0
"""
@@ -870,11 +962,13 @@ def describeTopics(self, maxTermsPerTopic=None):
WARNING: If vocabSize and k are large, this can return a large object!
- :param maxTermsPerTopic: Maximum number of terms to collect for each topic.
- (default: vocabulary size)
- :return: Array over topics. Each topic is represented as a pair of matching arrays:
- (term indices, term weights in topic).
- Each topic's terms are sorted in order of decreasing weight.
+ :param maxTermsPerTopic:
+ Maximum number of terms to collect for each topic.
+ (default: vocabulary size)
+ :return:
+ Array over topics. Each topic is represented as a pair of
+ matching arrays: (term indices, term weights in topic).
+ Each topic's terms are sorted in order of decreasing weight.
"""
if maxTermsPerTopic is None:
topics = self.call("describeTopics")
@@ -887,8 +981,10 @@ def describeTopics(self, maxTermsPerTopic=None):
def load(cls, sc, path):
"""Load the LDAModel from disk.
- :param sc: SparkContext
- :param path: str, path to where the model is stored.
+ :param sc:
+ SparkContext.
+ :param path:
+ Path to where the model is stored.
"""
if not isinstance(sc, SparkContext):
raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
@@ -909,17 +1005,38 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"):
"""Train a LDA model.
- :param rdd: RDD of data points
- :param k: Number of clusters you want
- :param maxIterations: Number of iterations. Default to 20
- :param docConcentration: Concentration parameter (commonly named "alpha")
- for the prior placed on documents' distributions over topics ("theta").
- :param topicConcentration: Concentration parameter (commonly named "beta" or "eta")
- for the prior placed on topics' distributions over terms.
- :param seed: Random Seed
- :param checkpointInterval: Period (in iterations) between checkpoints.
- :param optimizer: LDAOptimizer used to perform the actual calculation.
- Currently "em", "online" are supported. Default to "em".
+ :param rdd:
+ RDD of documents, which are tuples of document IDs and term
+ (word) count vectors. The term count vectors are "bags of
+ words" with a fixed-size vocabulary (where the vocabulary size
+ is the length of the vector). Document IDs must be unique
+ and >= 0.
+ :param k:
+ Number of topics to infer, i.e., the number of soft cluster
+ centers.
+ (default: 10)
+ :param maxIterations:
+ Maximum number of iterations allowed.
+ (default: 20)
+ :param docConcentration:
+ Concentration parameter (commonly named "alpha") for the prior
+ placed on documents' distributions over topics ("theta").
+ (default: -1.0)
+ :param topicConcentration:
+ Concentration parameter (commonly named "beta" or "eta") for
+ the prior placed on topics' distributions over terms.
+ (default: -1.0)
+ :param seed:
+ Random seed for cluster initialization. Set as None to generate
+ seed based on system time.
+ (default: None)
+ :param checkpointInterval:
+ Period (in iterations) between checkpoints.
+ (default: 10)
+ :param optimizer:
+ LDAOptimizer used to perform the actual calculation. Currently
+ "em", "online" are supported.
+ (default: "em")
"""
model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
docConcentration, topicConcentration, seed,
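An illustrative sketch of the documented input format and a train() call (assumptions: an active SparkContext `sc`; the tiny corpus is made up):

from pyspark.mllib.linalg import Vectors
from pyspark.mllib.clustering import LDA

# Each document is (unique id >= 0, term-count vector over a fixed vocabulary).
corpus = sc.parallelize([
    [0, Vectors.dense([3.0, 1.0, 0.0])],
    [1, Vectors.dense([0.0, 2.0, 4.0])],
])
ldaModel = LDA.train(corpus, k=2, maxIterations=20, optimizer="em")
print(ldaModel.describeTopics(maxTermsPerTopic=2))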
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 13b3397501c0b..4dd7083d79c8c 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -219,8 +219,10 @@ class LinearRegressionWithSGD(object):
"""
Train a linear regression model with no regularization using Stochastic Gradient Descent.
This solves the least squares regression formulation
- f(weights) = 1/n ||A weights-y||^2^
- (which is the mean squared error).
+
+ f(weights) = 1/n ||A weights-y||^2
+
+ which is the mean squared error.
Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
its corresponding right hand side label y.
See also the documentation for the precise formulation.
@@ -367,8 +369,10 @@ def load(cls, sc, path):
class LassoWithSGD(object):
"""
Train a regression model with L1-regularization using Stochastic Gradient Descent.
- This solves the l1-regularized least squares regression formulation
- f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1
+ This solves the L1-regularized least squares regression formulation
+
+ f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1
+
Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
its corresponding right hand side label y.
See also the documentation for the precise formulation.
@@ -505,8 +509,10 @@ def load(cls, sc, path):
class RidgeRegressionWithSGD(object):
"""
Train a regression model with L2-regularization using Stochastic Gradient Descent.
- This solves the l2-regularized least squares regression formulation
- f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^
+ This solves the L2-regularized least squares regression formulation
+
+ f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2
+
Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
its corresponding right hand side label y.
See also the documentation for the precise formulation.
@@ -655,17 +661,19 @@ class IsotonicRegression(object):
Only univariate (single feature) algorithm supported.
Sequential PAV implementation based on:
- Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani.
+
+ Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani.
"Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61.
- Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]]
+ Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf
Sequential PAV parallelization based on:
- Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset.
- "An approach to parallelizing isotonic regression."
- Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147.
- Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]]
- @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]]
+ Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset.
+ "An approach to parallelizing isotonic regression."
+ Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147.
+ Available from http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf
+
+ See `Isotonic regression (Wikipedia) <http://en.wikipedia.org/wiki/Isotonic_regression>`_.
.. versionadded:: 1.4.0
"""
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c28594625457a..fe2264a63cf30 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -426,6 +426,9 @@ def takeSample(self, withReplacement, num, seed=None):
"""
Return a fixed-size sampled subset of this RDD.
+ Note that this method should only be used if the resulting array is expected
+ to be small, as all the data is loaded into the driver's memory.
+
>>> rdd = sc.parallelize(range(0, 10))
>>> len(rdd.takeSample(True, 20, 1))
20
@@ -766,6 +769,8 @@ def func(it):
def collect(self):
"""
Return a list that contains all of the elements in this RDD.
+ Note that this method should only be used if the resulting array is expected
+ to be small, as all the data is loaded into the driver's memory.
"""
with SCCallSiteSync(self.context) as css:
port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
@@ -1213,6 +1218,9 @@ def top(self, num, key=None):
"""
Get the top N elements from a RDD.
+ Note that this method should only be used if the resulting array is expected
+ to be small, as all the data is loaded into the driver's memory.
+
Note: It returns the list sorted in descending order.
>>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
@@ -1235,6 +1243,9 @@ def takeOrdered(self, num, key=None):
Get the N elements from a RDD ordered in ascending order or as
specified by the optional key function.
+ Note that this method should only be used if the resulting array is expected
+ to be small, as all the data is loaded into the driver's memory.
+
>>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
[1, 2, 3, 4, 5, 6]
>>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
@@ -1254,6 +1265,9 @@ def take(self, num):
that partition to estimate the number of additional partitions needed
to satisfy the limit.
+ Note that this method should only be used if the resulting array is expected
+ to be small, as all the data is loaded into the driver's memory.
+
Translated from the Scala implementation in RDD#take().
>>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
@@ -1511,6 +1525,9 @@ def collectAsMap(self):
"""
Return the key-value pairs in this RDD to the master as a dictionary.
+ Note that this method should only be used if the resulting data is expected
+ to be small, as all the data is loaded into the driver's memory.
+
>>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
>>> m[1]
2
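The same caveat motivates preferring bounded actions when only part of the result is needed; a small sketch (assumes an active SparkContext `sc`):

rdd = sc.parallelize(range(10000))

rdd.take(5)          # only the first 5 elements reach the driver
rdd.takeOrdered(5)   # the 5 smallest elements
rdd.top(5)           # the 5 largest elements
# rdd.collect() would materialize all 10000 elements in driver memory.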
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 90a6b5d9c0dda..3a8c8305ee3d8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -739,6 +739,9 @@ def describe(self, *cols):
def head(self, n=None):
"""Returns the first ``n`` rows.
+ Note that this method should only be used if the resulting array is expected
+ to be small, as all the data is loaded into the driver's memory.
+
:param n: int, default 1. Number of rows to return.
:return: If n is greater than 1, return a list of :class:`Row`.
If n is 1, return a single Row.
@@ -1330,6 +1333,9 @@ def toDF(self, *cols):
def toPandas(self):
"""Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
+ Note that this method should only be used if the resulting Pandas DataFrame is expected
+ to be small, as all the data is loaded into the driver's memory.
+
This is only available if Pandas is installed and available.
>>> df.toPandas() # doctest: +SKIP
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 719eca8f5559e..0d5708526701e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -81,8 +81,6 @@ def _():
'max': 'Aggregate function: returns the maximum value of the expression in a group.',
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
- 'first': 'Aggregate function: returns the first value in a group.',
- 'last': 'Aggregate function: returns the last value in a group.',
'count': 'Aggregate function: returns the number of items in a group.',
'sum': 'Aggregate function: returns the sum of all values in the expression.',
'avg': 'Aggregate function: returns the average of the values in a group.',
@@ -278,6 +276,18 @@ def countDistinct(col, *cols):
return Column(jc)
+@since(1.3)
+def first(col, ignorenulls=False):
+ """Aggregate function: returns the first value in a group.
+
+ The function by default returns the first value it sees. It will return the first non-null
+ value it sees when ignorenulls is set to true. If all values are null, then null is returned.
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls)
+ return Column(jc)
+
+
@since(1.6)
def input_file_name():
"""Creates a string column for the file name of the current Spark task.
@@ -310,6 +320,18 @@ def isnull(col):
return Column(sc._jvm.functions.isnull(_to_java_column(col)))
+@since(1.3)
+def last(col, ignorenulls=False):
+ """Aggregate function: returns the last value in a group.
+
+ The function by default returns the last value it sees. It will return the last non-null
+ value it sees when ignorenulls is set to true. If all values are null, then null is returned.
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls)
+ return Column(jc)
+
+
@since(1.6)
def monotonically_increasing_id():
"""A column that generates monotonically increasing 64-bit integers.
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 0b20022b14b8d..b1453c637f79e 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -152,6 +152,8 @@ def json(self, path, schema=None):
You can set the following JSON-specific options to deal with non-standard JSON files:
* ``primitivesAsString`` (default ``false``): infers all primitive values as a string \
type
+ * ``floatAsBigDecimal`` (default ``false``): infers all floating-point values as a decimal \
+ type
* ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records
* ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names
* ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 410efbafe0792..e30aa0a796924 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -641,6 +641,16 @@ def test_aggregator(self):
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
+ def test_first_last_ignorenulls(self):
+ from pyspark.sql import functions
+ df = self.sqlCtx.range(0, 100)
+ df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
+ df3 = df2.select(functions.first(df2.id, False).alias('a'),
+ functions.first(df2.id, True).alias('b'),
+ functions.last(df2.id, False).alias('c'),
+ functions.last(df2.id, True).alias('d'))
+ self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
+
def test_corr(self):
import math
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 24b812615cbb4..b33e8252a7d32 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1013,12 +1013,12 @@ def setUp(self):
self._kafkaTestUtils.setup()
def tearDown(self):
+ super(KafkaStreamTests, self).tearDown()
+
if self._kafkaTestUtils is not None:
self._kafkaTestUtils.teardown()
self._kafkaTestUtils = None
- super(KafkaStreamTests, self).tearDown()
-
def _randomTopic(self):
return "topic-%d" % random.randint(0, 10000)
diff --git a/repl/pom.xml b/repl/pom.xml
index efc3dd452e329..0f396c9b809bd 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -20,13 +20,13 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
     <version>2.0.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

   <groupId>org.apache.spark</groupId>
-  <artifactId>spark-repl_2.10</artifactId>
+  <artifactId>spark-repl_2.11</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project REPL</name>
   <url>http://spark.apache.org/</url>
@@ -159,7 +159,7 @@
scala-2.10
- !scala-2.11
+ scala-2.10
@@ -173,7 +173,7 @@
scala-2.11
- scala-2.11
+ !scala-2.10scala-2.11/src/main/scala
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
index bb3081d12938e..07ba28bb07545 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
@@ -33,7 +33,8 @@ object Main extends Logging {
var sparkContext: SparkContext = _
var sqlContext: SQLContext = _
- var interp = new SparkILoop // this is a public var because tests reset it.
+ // this is a public var because tests reset it.
+ var interp: SparkILoop = _
private var hasErrors = false
@@ -43,6 +44,12 @@ object Main extends Logging {
}
def main(args: Array[String]) {
+ doMain(args, new SparkILoop)
+ }
+
+ // Visible for testing
+ private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = {
+ interp = _interp
val interpArguments = List(
"-Yrepl-class-based",
"-Yrepl-outdir", s"${outputDir.getAbsolutePath}",
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 63f3688c9e612..b9ed79da421a6 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -50,12 +50,7 @@ class ReplSuite extends SparkFunSuite {
System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath)
System.setProperty("spark.master", master)
- val interp = {
- new SparkILoop(in, new PrintWriter(out))
- }
- org.apache.spark.repl.Main.interp = interp
- Main.main(Array("-classpath", classpath)) // call main
- org.apache.spark.repl.Main.interp = null
+ Main.doMain(Array("-classpath", classpath), new SparkILoop(in, new PrintWriter(out)))
if (oldExecutorClasspath != null) {
System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath)
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 76ca3f3bb1bfa..c2ad9b99f3ac9 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -21,13 +21,13 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
     <version>2.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>

   <groupId>org.apache.spark</groupId>
-  <artifactId>spark-catalyst_2.10</artifactId>
+  <artifactId>spark-catalyst_2.11</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Catalyst</name>
   <url>http://spark.apache.org/</url>
@@ -127,13 +127,4 @@
-
-
-
- scala-2.10
-
- !scala-2.11
-
-
-
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
index 0555a6ba83cbb..c162c1a0c5789 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
@@ -493,6 +493,16 @@ descFuncNames
| functionIdentifier
;
+// We are allowed to use From and To in a CreateTableUsing command's options (actually it seems any string can be used as the option key). But we can't simply add them to nonReserved because doing so would break other existing rules. So we create looseIdentifier and looseNonReserved here.
+looseIdentifier
+ :
+ Identifier
+ | looseNonReserved -> Identifier[$looseNonReserved.text]
+ // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false,
+ // the sql11keywords in existing q tests will NOT be added back.
+ | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text]
+ ;
+
identifier
:
Identifier
@@ -516,6 +526,10 @@ principalIdentifier
| QuotedIdentifier
;
+looseNonReserved
+ : nonReserved | KW_FROM | KW_TO
+ ;
+
//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved
//Non reserved keywords are basically the keywords that can be used as identifiers.
//All the KW_* are automatically not only keywords, but also reserved keywords.
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
index 6d76afcd4ac07..e83f8a7cd1b5c 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
@@ -117,15 +117,20 @@ joinToken
@init { gParent.pushMsg("join type specifier", state); }
@after { gParent.popMsg(state); }
:
- KW_JOIN -> TOK_JOIN
- | KW_INNER KW_JOIN -> TOK_JOIN
- | COMMA -> TOK_JOIN
- | KW_CROSS KW_JOIN -> TOK_CROSSJOIN
- | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN
- | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN
- | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN
- | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN
- | KW_ANTI KW_JOIN -> TOK_ANTIJOIN
+ KW_JOIN -> TOK_JOIN
+ | KW_INNER KW_JOIN -> TOK_JOIN
+ | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN
+ | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN
+ | COMMA -> TOK_JOIN
+ | KW_CROSS KW_JOIN -> TOK_CROSSJOIN
+ | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN
+ | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN
+ | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN
+ | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN
+ | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN
+ | KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN
+ | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN
+ | KW_ANTI KW_JOIN -> TOK_ANTIJOIN
;
lateralView
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
index 4374cd7ef7200..fd1ad59207e31 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
@@ -324,6 +324,8 @@ KW_ISOLATION: 'ISOLATION';
KW_LEVEL: 'LEVEL';
KW_SNAPSHOT: 'SNAPSHOT';
KW_AUTOCOMMIT: 'AUTOCOMMIT';
+KW_REFRESH: 'REFRESH';
+KW_OPTIONS: 'OPTIONS';
KW_WEEK: 'WEEK'|'WEEKS';
KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS';
KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS';
@@ -333,6 +335,8 @@ KW_CACHE: 'CACHE';
KW_UNCACHE: 'UNCACHE';
KW_DFS: 'DFS';
+KW_NATURAL: 'NATURAL';
+
// Operators
// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work.
@@ -470,7 +474,7 @@ Identifier
fragment
QuotedIdentifier
:
- '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); }
+ '`' ( '``' | ~('`') )* '`' { setText(getText().replaceAll("``", "`")); }
;
WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;}
@@ -481,3 +485,7 @@ COMMENT
{ $channel=HIDDEN; }
;
+/* Prevent the lexer from swallowing unknown characters. */
+ANY
+ :.
+ ;
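Together with the QuotedIdentifier change above (and the cleanIdentifier handling further down in CatalystQl), backquoted identifiers now survive the lexer intact, so names containing dots or keywords can be quoted. An illustrative query from the Python side (assumes a SQLContext `sqlCtx` and a registered table tbl with such a column):

# A column named `a.b` stays a single identifier instead of being split on '.'.
sqlCtx.sql("SELECT `a.b` FROM tbl")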
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
index 35bef00351d72..9935678ca2ca2 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
@@ -96,6 +96,10 @@ TOK_RIGHTOUTERJOIN;
TOK_FULLOUTERJOIN;
TOK_UNIQUEJOIN;
TOK_CROSSJOIN;
+TOK_NATURALJOIN;
+TOK_NATURALLEFTOUTERJOIN;
+TOK_NATURALRIGHTOUTERJOIN;
+TOK_NATURALFULLOUTERJOIN;
TOK_LOAD;
TOK_EXPORT;
TOK_IMPORT;
@@ -142,6 +146,7 @@ TOK_UNIONTYPE;
TOK_COLTYPELIST;
TOK_CREATEDATABASE;
TOK_CREATETABLE;
+TOK_CREATETABLEUSING;
TOK_TRUNCATETABLE;
TOK_CREATEINDEX;
TOK_CREATEINDEX_INDEXTBLNAME;
@@ -371,6 +376,10 @@ TOK_TXN_READ_WRITE;
TOK_COMMIT;
TOK_ROLLBACK;
TOK_SET_AUTOCOMMIT;
+TOK_REFRESHTABLE;
+TOK_TABLEPROVIDER;
+TOK_TABLEOPTIONS;
+TOK_TABLEOPTION;
TOK_CACHETABLE;
TOK_UNCACHETABLE;
TOK_CLEARCACHE;
@@ -660,6 +669,12 @@ import java.util.HashMap;
}
private char [] excludedCharForColumnName = {'.', ':'};
private boolean containExcludedCharForCreateTableColumnName(String input) {
+ if (input.length() > 0) {
+ if (input.charAt(0) == '`' && input.charAt(input.length() - 1) == '`') {
+ // When column name is backquoted, we don't care about excluded chars.
+ return false;
+ }
+ }
for(char c : excludedCharForColumnName) {
if(input.indexOf(c)>-1) {
return true;
@@ -781,6 +796,7 @@ ddlStatement
| truncateTableStatement
| alterStatement
| descStatement
+ | refreshStatement
| showStatement
| metastoreCheck
| createViewStatement
@@ -907,12 +923,31 @@ createTableStatement
@init { pushMsg("create table statement", state); }
@after { popMsg(state); }
: KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName
- ( like=KW_LIKE likeName=tableName
+ (
+ like=KW_LIKE likeName=tableName
tableRowFormat?
tableFileFormat?
tableLocation?
tablePropertiesPrefixed?
+ -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists?
+ ^(TOK_LIKETABLE $likeName?)
+ tableRowFormat?
+ tableFileFormat?
+ tableLocation?
+ tablePropertiesPrefixed?
+ )
+ |
+ tableProvider
+ tableOpts?
+ (KW_AS selectStatementWithCTE)?
+ -> ^(TOK_CREATETABLEUSING $name $temp? ifNotExists?
+ tableProvider
+ tableOpts?
+ selectStatementWithCTE?
+ )
| (LPAREN columnNameTypeList RPAREN)?
+ (p=tableProvider?)
+ tableOpts?
tableComment?
tablePartition?
tableBuckets?
@@ -922,8 +957,15 @@ createTableStatement
tableLocation?
tablePropertiesPrefixed?
(KW_AS selectStatementWithCTE)?
- )
- -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists?
+ -> {p != null}?
+ ^(TOK_CREATETABLEUSING $name $temp? ifNotExists?
+ columnNameTypeList?
+ $p
+ tableOpts?
+ selectStatementWithCTE?
+ )
+ ->
+ ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists?
^(TOK_LIKETABLE $likeName?)
columnNameTypeList?
tableComment?
@@ -935,7 +977,8 @@ createTableStatement
tableLocation?
tablePropertiesPrefixed?
selectStatementWithCTE?
- )
+ )
+ )
;
truncateTableStatement
@@ -1379,6 +1422,13 @@ tabPartColTypeExpr
: tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?)
;
+refreshStatement
+@init { pushMsg("refresh statement", state); }
+@after { popMsg(state); }
+ :
+ KW_REFRESH KW_TABLE tableName -> ^(TOK_REFRESHTABLE tableName)
+ ;
+
descStatement
@init { pushMsg("describe statement", state); }
@after { popMsg(state); }
@@ -1774,6 +1824,30 @@ showStmtIdentifier
| StringLiteral
;
+tableProvider
+@init { pushMsg("table's provider", state); }
+@after { popMsg(state); }
+ :
+ KW_USING Identifier (DOT Identifier)*
+ -> ^(TOK_TABLEPROVIDER Identifier+)
+ ;
+
+optionKeyValue
+@init { pushMsg("table's option specification", state); }
+@after { popMsg(state); }
+ :
+ (looseIdentifier (DOT looseIdentifier)*) StringLiteral
+ -> ^(TOK_TABLEOPTION looseIdentifier+ StringLiteral)
+ ;
+
+tableOpts
+@init { pushMsg("table's options", state); }
+@after { popMsg(state); }
+ :
+ KW_OPTIONS LPAREN optionKeyValue (COMMA optionKeyValue)* RPAREN
+ -> ^(TOK_TABLEOPTIONS optionKeyValue+)
+ ;
+
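A sketch of the kind of DDL these new rules are intended to accept, as it might be issued from Python (assumptions: a SQLContext `sqlCtx`; the data source class and path are made up):

sqlCtx.sql("""
  CREATE TEMPORARY TABLE users
  USING org.apache.spark.sql.parquet
  OPTIONS (path "/tmp/users.parquet")
""")
sqlCtx.sql("REFRESH TABLE users")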
tableComment
@init { pushMsg("table's comment", state); }
@after { popMsg(state); }
@@ -2132,7 +2206,7 @@ structType
mapType
@init { pushMsg("map type", state); }
@after { popMsg(state); }
- : KW_MAP LESSTHAN left=primitiveType COMMA right=type GREATERTHAN
+ : KW_MAP LESSTHAN left=type COMMA right=type GREATERTHAN
-> ^(TOK_MAP $left $right)
;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index 536c292ab7f34..a42360d5629f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -140,6 +140,7 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends
case Token("TOK_BOOLEAN", Nil) => BooleanType
case Token("TOK_STRING", Nil) => StringType
case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType
+ case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType
case Token("TOK_FLOAT", Nil) => FloatType
case Token("TOK_DOUBLE", Nil) => DoubleType
case Token("TOK_DATE", Nil) => DateType
@@ -156,9 +157,10 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends
protected def nodeToStructField(node: ASTNode): StructField = node match {
case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) =>
- StructField(fieldName, nodeToDataType(dataType), nullable = true)
- case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: _ /* comment */:: Nil) =>
- StructField(fieldName, nodeToDataType(dataType), nullable = true)
+ StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true)
+ case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) =>
+ val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build()
+ StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta)
case _ =>
noParseRule("StructField", node)
}
@@ -222,15 +224,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Nil =>
ShowFunctions(None, None)
case Token(name, Nil) :: Nil =>
- ShowFunctions(None, Some(unquoteString(name)))
+ ShowFunctions(None, Some(unquoteString(cleanIdentifier(name))))
case Token(db, Nil) :: Token(name, Nil) :: Nil =>
- ShowFunctions(Some(unquoteString(db)), Some(unquoteString(name)))
+ ShowFunctions(Some(unquoteString(cleanIdentifier(db))),
+ Some(unquoteString(cleanIdentifier(name))))
case _ =>
noParseRule("SHOW FUNCTIONS", node)
}
case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) =>
- DescribeFunction(functionName, isExtended.nonEmpty)
+ DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty)
case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) =>
val (fromClause: Option[ASTNode], insertClauses, cteRelations) =
@@ -517,6 +520,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case "TOK_LEFTSEMIJOIN" => LeftSemi
case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
+ case "TOK_NATURALJOIN" => NaturalJoin(Inner)
+ case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
+ case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
+ case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
}
Join(nodeToRelation(relation1),
nodeToRelation(relation2),
@@ -611,7 +618,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
noParseRule("Select", node)
}
- protected val escapedIdentifier = "`([^`]+)`".r
+ protected val escapedIdentifier = "`(.+)`".r
protected val doubleQuotedString = "\"([^\"]+)\"".r
protected val singleQuotedString = "'([^']+)'".r
@@ -655,7 +662,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
nodeToExpr(qualifier) match {
case UnresolvedAttribute(nameParts) =>
UnresolvedAttribute(nameParts :+ cleanIdentifier(attr))
- case other => UnresolvedExtractValue(other, Literal(attr))
+ case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr)))
}
/* Stars (*) */
@@ -663,7 +670,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
// has a single child which is tableName.
case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty =>
- UnresolvedStar(Some(target.map(_.text)))
+ UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text))))
/* Aggregate Functions */
case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
@@ -971,7 +978,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = {
val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node
- val alias = getClause("TOK_TABALIAS", clauses).children.head.text
+ val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text)
val generator = clauses.head match {
case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 3c3717d5043aa..59ee41d02f198 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -292,7 +292,7 @@ object JavaTypeInference {
val setter = if (nullable) {
constructor
} else {
- AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+ AssertNotNull(constructor, Seq("currently no type path record in java"))
}
p.getWriteMethod.getName -> setter
}.toMap
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 643228d0eb27d..02cb2d9a2b118 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
+
+ // TODO: add runtime null check for primitive array
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => Some("toIntArray")
case t if t <:< definitions.LongTpe => Some("toLongArray")
@@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
- val arrayData =
- Invoke(
- MapObjects(
- p => constructorFor(elementType, Some(p), newTypePath),
- getPath,
- schemaFor(elementType).dataType),
- "array",
- ObjectType(classOf[Array[Any]]))
+
+ val mapFunction: Expression => Expression = p => {
+ val converter = constructorFor(elementType, Some(p), newTypePath)
+ if (nullable) {
+ converter
+ } else {
+ AssertNotNull(converter, newTypePath)
+ }
+ }
+
+ val array = Invoke(
+ MapObjects(mapFunction, getPath, dataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
- arrayData :: Nil)
+ array :: Nil)
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
@@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection {
newTypePath)
if (!nullable) {
- AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+ AssertNotNull(constructor, newTypePath)
} else {
constructor
}
@@ -601,6 +610,20 @@ object ScalaReflection extends ScalaReflection {
getConstructorParameters(t)
}
+ /**
+ * Returns the parameter names for the primary constructor of this class.
+ *
+ * Logically we should call `getConstructorParameters` and throw away the parameter types to get
+ * parameter names, however there are some weird scala reflection problems and this method is a
+ * workaround to avoid getting parameter types.
+ */
+ def getConstructorParameterNames(cls: Class[_]): Seq[String] = {
+ val m = runtimeMirror(cls.getClassLoader)
+ val classSymbol = m.staticClass(cls.getName)
+ val t = classSymbol.selfType
+ constructParams(t).map(_.name.toString)
+ }
+
def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
}
@@ -745,6 +768,12 @@ trait ScalaReflection {
def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = tpe
+ constructParams(tpe).map { p =>
+ p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ }
+ }
+
+ protected def constructParams(tpe: Type): Seq[Symbol] = {
val constructorSymbol = tpe.member(nme.CONSTRUCTOR)
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramss
@@ -758,9 +787,6 @@ trait ScalaReflection {
primaryConstructorSymbol.get.asMethod.paramss
}
}
-
- params.flatten.map { p =>
- p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
- }
+ params.flatten
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 33d76eeb21287..4d53b232d5510 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.catalyst.analysis
+import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
@@ -80,6 +83,7 @@ class Analyzer(
ResolveAliases ::
ResolveWindowOrder ::
ResolveWindowFrame ::
+ ResolveNaturalJoin ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
@@ -344,6 +348,63 @@ class Analyzer(
}
}
+ /**
+ * Generate a new logical plan for the right child with different expression IDs
+ * for all conflicting attributes.
+ */
+ private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
+ val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+ logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
+ s"between $left and $right")
+
+ right.collect {
+ // Handle base relations that might appear more than once.
+ case oldVersion: MultiInstanceRelation
+ if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+ val newVersion = oldVersion.newInstance()
+ (oldVersion, newVersion)
+
+ // Handle projects that create conflicting aliases.
+ case oldVersion @ Project(projectList, _)
+ if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
+ (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
+
+ case oldVersion @ Aggregate(_, aggregateExpressions, _)
+ if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
+ (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
+
+ case oldVersion: Generate
+ if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+ val newOutput = oldVersion.generatorOutput.map(_.newInstance())
+ (oldVersion, oldVersion.copy(generatorOutput = newOutput))
+
+ case oldVersion @ Window(_, windowExpressions, _, _, child)
+ if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
+ .nonEmpty =>
+ (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
+ }
+ // Only handle first case, others will be fixed on the next pass.
+ .headOption match {
+ case None =>
+ /*
+ * No result implies that there is a logical plan node that produces new references
+ * that this rule cannot handle. When that is the case, there must be another rule
+ * that resolves these conflicts. Otherwise, the analysis will fail.
+ */
+ right
+ case Some((oldRelation, newRelation)) =>
+ val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
+ val newRight = right transformUp {
+ case r if r == oldRelation => newRelation
+ } transformUp {
+ case other => other transformExpressions {
+ case a: Attribute => attributeRewrites.get(a).getOrElse(a)
+ }
+ }
+ newRight
+ }
+ }
+
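The situations dedupRight() now covers, sketched from the Python API (assumes a DataFrame `df` with an `id` column): both sides of a self-join, or of an Intersect, initially carry the same attribute IDs, so the right side is re-instantiated with fresh IDs.

joined = df.join(df, "id")           # self-join: duplicate expression IDs
common = df.intersect(df.limit(5))   # Intersect with overlapping attribute IDs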
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
@@ -388,80 +449,43 @@ class Analyzer(
.map(_.asInstanceOf[NamedExpression])
a.copy(aggregateExpressions = expanded)
- // Special handling for cases when self-join introduce duplicate expression ids.
- case j @ Join(left, right, _, _) if !j.selfJoinResolved =>
- val conflictingAttributes = left.outputSet.intersect(right.outputSet)
- logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")
-
- right.collect {
- // Handle base relations that might appear more than once.
- case oldVersion: MultiInstanceRelation
- if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- val newVersion = oldVersion.newInstance()
- (oldVersion, newVersion)
-
- // Handle projects that create conflicting aliases.
- case oldVersion @ Project(projectList, _)
- if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
- (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
-
- case oldVersion @ Aggregate(_, aggregateExpressions, _)
- if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
- (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
-
- case oldVersion: Generate
- if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
- val newOutput = oldVersion.generatorOutput.map(_.newInstance())
- (oldVersion, oldVersion.copy(generatorOutput = newOutput))
-
- case oldVersion @ Window(_, windowExpressions, _, _, child)
- if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
- .nonEmpty =>
- (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
- }
- // Only handle first case, others will be fixed on the next pass.
- .headOption match {
- case None =>
- /*
- * No result implies that there is a logical plan node that produces new references
- * that this rule cannot handle. When that is the case, there must be another rule
- * that resolves these conflicts. Otherwise, the analysis will fail.
- */
- j
- case Some((oldRelation, newRelation)) =>
- val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
- val newRight = right transformUp {
- case r if r == oldRelation => newRelation
- } transformUp {
- case other => other transformExpressions {
- case a: Attribute => attributeRewrites.get(a).getOrElse(a)
- }
- }
- j.copy(right = newRight)
- }
+ // To resolve duplicate expression IDs for Join and Intersect
+ case j @ Join(left, right, _, _) if !j.duplicateResolved =>
+ j.copy(right = dedupRight(left, right))
+ case i @ Intersect(left, right) if !i.duplicateResolved =>
+ i.copy(right = dedupRight(left, right))
// When resolve `SortOrder`s in Sort based on child, don't report errors as
- // we still have chance to resolve it based on grandchild
+ // we still have a chance to resolve it based on its descendants
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
- val newOrdering = resolveSortOrders(ordering, child, throws = false)
+ val newOrdering =
+ ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
Sort(newOrdering, global, child)
// A special case for Generate, because the output of Generate should not be resolved by
// ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
case g @ Generate(generator, join, outer, qualifier, output, child)
if child.resolved && !generator.resolved =>
- val newG = generator transformUp {
- case u @ UnresolvedAttribute(nameParts) =>
- withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) }
- case UnresolvedExtractValue(child, fieldExpr) =>
- ExtractValue(child, fieldExpr, resolver)
- }
+ val newG = resolveExpression(generator, child, throws = true)
if (newG.fastEquals(generator)) {
g
} else {
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}
+ // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
+ // should be resolved by their corresponding attributes instead of children's output.
+ case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
+ val deserializerToAttributes = o.deserializers.map {
+ case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
+ }.toMap
+
+ o.transformExpressions {
+ case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
+ resolveDeserializer(expr, attributes)
+ }.getOrElse(expr)
+ }
+
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
@@ -476,6 +500,32 @@ class Analyzer(
}
}
+ private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
+ exprs.exists { expr =>
+ !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
+ }
+ }
+
+ def resolveDeserializer(
+ deserializer: Expression,
+ attributes: Seq[Attribute]): Expression = {
+ val unbound = deserializer transform {
+ case b: BoundReference => attributes(b.ordinal)
+ }
+
+ resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
+ case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
+ val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
+ if (outer == null) {
+ throw new AnalysisException(
+ s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
+ "access to the scope that this class was defined in.\n" +
+ "Try moving this class out of its parent class.")
+ }
+ n.copy(outerPointer = Some(Literal.fromObject(outer)))
+ }
+ }
+
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)()
@@ -494,23 +544,20 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
- private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
- ordering.map { order =>
- // Resolve SortOrder in one round.
- // If throws == false or the desired attribute doesn't exist
- // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
- // Else, throw exception.
- try {
- val newOrder = order transformUp {
- case u @ UnresolvedAttribute(nameParts) =>
- plan.resolve(nameParts, resolver).getOrElse(u)
- case UnresolvedExtractValue(child, fieldName) if child.resolved =>
- ExtractValue(child, fieldName, resolver)
- }
- newOrder.asInstanceOf[SortOrder]
- } catch {
- case a: AnalysisException if !throws => order
+ private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
+ // Resolve expression in one round.
+ // If throws == false or the desired attribute doesn't exist
+ // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
+ // Else, throw exception.
+ try {
+ expr transformUp {
+ case u @ UnresolvedAttribute(nameParts) =>
+ withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
+ case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+ ExtractValue(child, fieldName, resolver)
}
+ } catch {
+ case a: AnalysisException if !throws => expr
}
}
@@ -522,38 +569,97 @@ class Analyzer(
*/
object ResolveSortReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case s @ Sort(ordering, global, p @ Project(projectList, child))
- if !s.resolved && p.resolved =>
- val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)
+ // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
+ case sa @ Sort(_, _, child: Aggregate) => sa
- // If this rule was not a no-op, return the transformed plan, otherwise return the original.
- if (missing.nonEmpty) {
- // Add missing attributes and then project them away after the sort.
- Project(p.output,
- Sort(newOrdering, global,
- Project(projectList ++ missing, child)))
- } else {
- logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
+ case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
+ val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
+
+ if (missingResolvableAttrs.isEmpty) {
+ val unresolvableAttrs = s.order.filterNot(_.resolved)
+ logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
+ } else {
+ // Add the missing attributes into projectList of Project/Window or
+ // aggregateExpressions of Aggregate, if they are in the inputSet
+ // but not in the outputSet of the plan.
+ val newChild = child transformUp {
+ case p: Project =>
+ p.copy(projectList = p.projectList ++
+ missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
+ case w: Window =>
+ w.copy(projectList = w.projectList ++
+ missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
+ case a: Aggregate =>
+ val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
+ val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
+ val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
+ a.copy(aggregateExpressions = newAggregateExpressions)
+ case o => o
+ }
+
+ // Add missing attributes and then project them away after the sort.
+ Project(child.output,
+ Sort(newOrdering, s.global, newChild))
}
}
/**
- * Given a child and a grandchild that are present beneath a sort operator, try to resolve
- * the sort ordering and returns it with a list of attributes that are missing from the
- * child but are present in the grandchild.
+ * Traverse the tree until the sorting attributes are resolved.
+ * Return all the resolvable missing sorting attributes.
+ */
+ @tailrec
+ private def collectResolvableMissingAttrs(
+ ordering: Seq[SortOrder],
+ plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+ plan match {
+ // Only Window and Project have a projectList-like attribute.
+ case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
+ val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
+ // If missingAttrs is non-empty, we have found the missing attributes and return them;
+ // otherwise, continue to traverse the tree.
+ if (missingAttrs.nonEmpty) {
+ (newOrdering, missingAttrs)
+ } else {
+ collectResolvableMissingAttrs(ordering, un.child)
+ }
+ case a: Aggregate =>
+ val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
+ // For Aggregate, all the order by columns must be specified in group by clauses
+ if (missingAttrs.nonEmpty &&
+ missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
+ (newOrdering, missingAttrs)
+ } else {
+ // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
+ (Seq.empty[SortOrder], Seq.empty[Attribute])
+ }
+ // Jump over the following UnaryNode types
+ // The output of these types is the same as their child's output
+ case _: Distinct |
+ _: Filter |
+ _: RepartitionByExpression =>
+ collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
+ // If hitting the other unsupported operators, we are unable to resolve it.
+ case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
+ }
+ }
+
+ /**
+ * Try to resolve the sort ordering and return it with a list of attributes that are missing
+ * from the plan but are present in its child.
*/
- def resolveAndFindMissing(
+ private def resolveAndFindMissing(
ordering: Seq[SortOrder],
- child: LogicalPlan,
- grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
- val newOrdering = resolveSortOrders(ordering, grandchild, throws = true)
+ plan: LogicalPlan,
+ child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+ val newOrdering =
+ ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
// Construct a set that contains all of the attributes that we need to evaluate the
// ordering.
val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
// Figure out which ones are missing from the projection, so that we can add them and
// remove them after the sort.
- val missingInProject = requiredAttributes -- child.output
+ val missingInProject = requiredAttributes -- plan.outputSet
// It is important to return the new SortOrders here, instead of waiting for the standard
// resolving process as adding attributes to the project below can actually introduce
// ambiguity that was not present before.
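The query shape this rewrite targets, sketched in SQL from Python (assumes a SQLContext `sqlCtx` and a registered table t with columns a and b): the ORDER BY column is not in the SELECT list, so it is added below the Sort and projected away again afterwards.

sqlCtx.sql("SELECT a FROM t ORDER BY b")
# With the new traversal this also works through Filter/Distinct and, when the
# ordering column appears in the GROUP BY clause, through Aggregate.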
@@ -708,7 +814,7 @@ class Analyzer(
}
}
- protected def containsAggregate(condition: Expression): Boolean = {
+ def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}
@@ -872,12 +978,13 @@ class Analyzer(
if (missingExpr.nonEmpty) {
extractedExprBuffer += ne
}
- ne.toAttribute
+ // the alias will be cleaned up in the rule CleanupAliases
+ ne
case e: Expression if e.foldable =>
e // No need to create an attribute reference if it will be evaluated as a Literal.
case e: Expression =>
// For other expressions, we extract it and replace it with an AttributeReference (with
- // an interal column name, e.g. "_w0").
+ // an internal column name, e.g. "_w0").
val withName = Alias(e, s"_w${extractedExprBuffer.length}")()
extractedExprBuffer += withName
withName.toAttribute
@@ -1159,6 +1266,50 @@ class Analyzer(
}
}
}
+
+ /**
+ * Removes natural joins by calculating the output columns based on the outputs of the two sides,
+ * then applies a Project on top of a normal Join to eliminate the natural join.
+ */
+ object ResolveNaturalJoin extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
+ // find common column names from both sides
+ val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
+ val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
+ val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
+ val joinPairs = leftKeys.zip(rightKeys)
+
+ // Add joinPairs to joinConditions
+ val newCondition = (condition ++ joinPairs.map {
+ case (l, r) => EqualTo(l, r)
+ }).reduceOption(And)
+
+ // columns not in joinPairs
+ val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
+ val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
+
+ // the output list looks like: join keys, columns from left, columns from right
+ val projectList = joinType match {
+ case LeftOuter =>
+ leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
+ case RightOuter =>
+ rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
+ case FullOuter =>
+ // in a full outer join, the join columns are coalesced so they are non-null if either side is.
+ val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
+ joinedCols ++
+ lUniqueOutput.map(_.withNullability(true)) ++
+ rUniqueOutput.map(_.withNullability(true))
+ case Inner =>
+ rightKeys ++ lUniqueOutput ++ rUniqueOutput
+ case _ =>
+ sys.error("Unsupported natural join type " + joinType)
+ }
+ // use Project to trim unnecessary fields
+ Project(projectList, Join(left, right, joinType, newCondition))
+ }
+ }
}
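A sketch of the resulting NATURAL JOIN behaviour from the Python side (assumes a SQLContext `sqlCtx` with registered tables t1(a, b) and t2(a, c)):

df = sqlCtx.sql("SELECT * FROM t1 NATURAL JOIN t2")
df.columns
# Per the projectList built above: join keys first, then columns unique to the
# left side, then columns unique to the right side -> ['a', 'b', 'c']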
/**
@@ -1275,7 +1426,7 @@ object ResolveUpCast extends Rule[LogicalPlan] {
fail(child, DateType, walkedTypePath)
case (StringType, to: NumericType) =>
fail(child, to, walkedTypePath)
- case _ => Cast(child, dataType)
+ case _ => Cast(child, dataType.asNullable)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index a8f89ce6de457..f2f9ec59417ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -46,6 +46,10 @@ trait Catalog {
def lookupRelation(tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan
+ def setCurrentDatabase(databaseName: String): Unit = {
+ throw new UnsupportedOperationException
+ }
+
/**
* Returns tuples of (tableName, isTemporary) for all tables in the given database.
* isTemporary is a Boolean value indicates if a table is a temporary or not.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index f2e78d97442e3..4a2f2b8bc6e4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -214,9 +214,8 @@ trait CheckAnalysis {
s"""Only a single table generating function is allowed in a SELECT clause, found:
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
- // Special handling for cases when self-join introduce duplicate expression ids.
- case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
- val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+ case j: Join if !j.duplicateResolved =>
+ val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
failAnalysis(
s"""
|Failure when resolving conflicting references in Join:
@@ -224,6 +223,15 @@ trait CheckAnalysis {
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)
+ case i: Intersect if !i.duplicateResolved =>
+ val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet)
+ failAnalysis(
+ s"""
+ |Failure when resolving conflicting references in Intersect:
+ |$plan
+ |Conflicting attributes: ${conflictingAttributes.mkString(",")}
+ |""".stripMargin)
+
case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 957ac89fa530d..57bdb164e1a0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -347,18 +347,12 @@ object HiveTypeCoercion {
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
- case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
- case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
- Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+ case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
+ case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
+ case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
+ case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
+ case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
+ case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
new file mode 100644
index 0000000000000..38be61c52a95e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.AnalysisException
+
+
+/**
+ * An in-memory (ephemeral) implementation of the system catalog.
+ *
+ * All public methods should be synchronized for thread-safety.
+ */
+class InMemoryCatalog extends Catalog {
+ import Catalog._
+
+ private class TableDesc(var table: Table) {
+ val partitions = new mutable.HashMap[PartitionSpec, TablePartition]
+ }
+
+ private class DatabaseDesc(var db: Database) {
+ val tables = new mutable.HashMap[String, TableDesc]
+ val functions = new mutable.HashMap[String, Function]
+ }
+
+ private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc]
+
+ private def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
+ val regex = pattern.replaceAll("\\*", ".*").r
+ names.filter { funcName => regex.pattern.matcher(funcName).matches() }
+ }
+
+ private def existsFunction(db: String, funcName: String): Boolean = {
+ assertDbExists(db)
+ catalog(db).functions.contains(funcName)
+ }
+
+ private def existsTable(db: String, table: String): Boolean = {
+ assertDbExists(db)
+ catalog(db).tables.contains(table)
+ }
+
+ private def existsPartition(db: String, table: String, spec: PartitionSpec): Boolean = {
+ assertTableExists(db, table)
+ catalog(db).tables(table).partitions.contains(spec)
+ }
+
+ private def assertDbExists(db: String): Unit = {
+ if (!catalog.contains(db)) {
+ throw new AnalysisException(s"Database $db does not exist")
+ }
+ }
+
+ private def assertFunctionExists(db: String, funcName: String): Unit = {
+ if (!existsFunction(db, funcName)) {
+ throw new AnalysisException(s"Function $funcName does not exist in $db database")
+ }
+ }
+
+ private def assertTableExists(db: String, table: String): Unit = {
+ if (!existsTable(db, table)) {
+ throw new AnalysisException(s"Table $table does not exist in $db database")
+ }
+ }
+
+ private def assertPartitionExists(db: String, table: String, spec: PartitionSpec): Unit = {
+ if (!existsPartition(db, table, spec)) {
+ throw new AnalysisException(s"Partition does not exist in database $db table $table: $spec")
+ }
+ }
+
+ // --------------------------------------------------------------------------
+ // Databases
+ // --------------------------------------------------------------------------
+
+ override def createDatabase(
+ dbDefinition: Database,
+ ignoreIfExists: Boolean): Unit = synchronized {
+ if (catalog.contains(dbDefinition.name)) {
+ if (!ignoreIfExists) {
+ throw new AnalysisException(s"Database ${dbDefinition.name} already exists.")
+ }
+ } else {
+ catalog.put(dbDefinition.name, new DatabaseDesc(dbDefinition))
+ }
+ }
+
+ override def dropDatabase(
+ db: String,
+ ignoreIfNotExists: Boolean,
+ cascade: Boolean): Unit = synchronized {
+ if (catalog.contains(db)) {
+ if (!cascade) {
+ // If cascade is false, make sure the database is empty.
+ if (catalog(db).tables.nonEmpty) {
+ throw new AnalysisException(s"Database $db is not empty. One or more tables exist.")
+ }
+ if (catalog(db).functions.nonEmpty) {
+ throw new AnalysisException(s"Database $db is not empty. One or more functions exist.")
+ }
+ }
+ // Remove the database.
+ catalog.remove(db)
+ } else {
+ if (!ignoreIfNotExists) {
+ throw new AnalysisException(s"Database $db does not exist")
+ }
+ }
+ }
+
+ override def alterDatabase(db: String, dbDefinition: Database): Unit = synchronized {
+ assertDbExists(db)
+ assert(db == dbDefinition.name)
+ catalog(db).db = dbDefinition
+ }
+
+ override def getDatabase(db: String): Database = synchronized {
+ assertDbExists(db)
+ catalog(db).db
+ }
+
+ override def listDatabases(): Seq[String] = synchronized {
+ catalog.keySet.toSeq
+ }
+
+ override def listDatabases(pattern: String): Seq[String] = synchronized {
+ filterPattern(listDatabases(), pattern)
+ }
+
+ // --------------------------------------------------------------------------
+ // Tables
+ // --------------------------------------------------------------------------
+
+ override def createTable(
+ db: String,
+ tableDefinition: Table,
+ ignoreIfExists: Boolean): Unit = synchronized {
+ assertDbExists(db)
+ if (existsTable(db, tableDefinition.name)) {
+ if (!ignoreIfExists) {
+ throw new AnalysisException(s"Table ${tableDefinition.name} already exists in $db database")
+ }
+ } else {
+ catalog(db).tables.put(tableDefinition.name, new TableDesc(tableDefinition))
+ }
+ }
+
+ override def dropTable(
+ db: String,
+ table: String,
+ ignoreIfNotExists: Boolean): Unit = synchronized {
+ assertDbExists(db)
+ if (existsTable(db, table)) {
+ catalog(db).tables.remove(table)
+ } else {
+ if (!ignoreIfNotExists) {
+ throw new AnalysisException(s"Table $table does not exist in $db database")
+ }
+ }
+ }
+
+ override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized {
+ assertTableExists(db, oldName)
+ val oldDesc = catalog(db).tables(oldName)
+ oldDesc.table = oldDesc.table.copy(name = newName)
+ catalog(db).tables.put(newName, oldDesc)
+ catalog(db).tables.remove(oldName)
+ }
+
+ override def alterTable(db: String, table: String, tableDefinition: Table): Unit = synchronized {
+ assertTableExists(db, table)
+ assert(table == tableDefinition.name)
+ catalog(db).tables(table).table = tableDefinition
+ }
+
+ override def getTable(db: String, table: String): Table = synchronized {
+ assertTableExists(db, table)
+ catalog(db).tables(table).table
+ }
+
+ override def listTables(db: String): Seq[String] = synchronized {
+ assertDbExists(db)
+ catalog(db).tables.keySet.toSeq
+ }
+
+ override def listTables(db: String, pattern: String): Seq[String] = synchronized {
+ assertDbExists(db)
+ filterPattern(listTables(db), pattern)
+ }
+
+ // --------------------------------------------------------------------------
+ // Partitions
+ // --------------------------------------------------------------------------
+
+ override def createPartitions(
+ db: String,
+ table: String,
+ parts: Seq[TablePartition],
+ ignoreIfExists: Boolean): Unit = synchronized {
+ assertTableExists(db, table)
+ val existingParts = catalog(db).tables(table).partitions
+ if (!ignoreIfExists) {
+ val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec }
+ if (dupSpecs.nonEmpty) {
+ val dupSpecsStr = dupSpecs.mkString("\n===\n")
+ throw new AnalysisException(
+ s"The following partitions already exist in database $db table $table:\n$dupSpecsStr")
+ }
+ }
+ parts.foreach { p => existingParts.put(p.spec, p) }
+ }
+
+ override def dropPartitions(
+ db: String,
+ table: String,
+ partSpecs: Seq[PartitionSpec],
+ ignoreIfNotExists: Boolean): Unit = synchronized {
+ assertTableExists(db, table)
+ val existingParts = catalog(db).tables(table).partitions
+ if (!ignoreIfNotExists) {
+ val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s }
+ if (missingSpecs.nonEmpty) {
+ val missingSpecsStr = missingSpecs.mkString("\n===\n")
+ throw new AnalysisException(
+ s"The following partitions do not exist in database $db table $table:\n$missingSpecsStr")
+ }
+ }
+ partSpecs.foreach(existingParts.remove)
+ }
+
+ override def alterPartition(
+ db: String,
+ table: String,
+ spec: Map[String, String],
+ newPart: TablePartition): Unit = synchronized {
+ assertPartitionExists(db, table, spec)
+ val existingParts = catalog(db).tables(table).partitions
+ if (spec != newPart.spec) {
+ // Also a change in specs; remove the old one and add the new one back
+ existingParts.remove(spec)
+ }
+ existingParts.put(newPart.spec, newPart)
+ }
+
+ override def getPartition(
+ db: String,
+ table: String,
+ spec: Map[String, String]): TablePartition = synchronized {
+ assertPartitionExists(db, table, spec)
+ catalog(db).tables(table).partitions(spec)
+ }
+
+ override def listPartitions(db: String, table: String): Seq[TablePartition] = synchronized {
+ assertTableExists(db, table)
+ catalog(db).tables(table).partitions.values.toSeq
+ }
+
+ // --------------------------------------------------------------------------
+ // Functions
+ // --------------------------------------------------------------------------
+
+ override def createFunction(
+ db: String,
+ func: Function,
+ ignoreIfExists: Boolean): Unit = synchronized {
+ assertDbExists(db)
+ if (existsFunction(db, func.name)) {
+ if (!ignoreIfExists) {
+ throw new AnalysisException(s"Function $func already exists in $db database")
+ }
+ } else {
+ catalog(db).functions.put(func.name, func)
+ }
+ }
+
+ override def dropFunction(db: String, funcName: String): Unit = synchronized {
+ assertFunctionExists(db, funcName)
+ catalog(db).functions.remove(funcName)
+ }
+
+ override def alterFunction(
+ db: String,
+ funcName: String,
+ funcDefinition: Function): Unit = synchronized {
+ assertFunctionExists(db, funcName)
+ if (funcName != funcDefinition.name) {
+ // Also a rename; remove the old one and add the new one back
+ catalog(db).functions.remove(funcName)
+ }
+ catalog(db).functions.put(funcDefinition.name, funcDefinition)
+ }
+
+ override def getFunction(db: String, funcName: String): Function = synchronized {
+ assertFunctionExists(db, funcName)
+ catalog(db).functions(funcName)
+ }
+
+ override def listFunctions(db: String, pattern: String): Seq[String] = synchronized {
+ assertDbExists(db)
+ filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
+ }
+
+}
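A minimal usage sketch of the catalog added above, assuming the `Database`, `Table`, `Column` and `StorageFormat` case classes from catalog/interface.scala further down in this patch; the field values are placeholders chosen only to satisfy the constructors, not meaningful defaults:

    val catalog = new InMemoryCatalog
    catalog.createDatabase(
      Database(name = "db1", description = "", locationUri = "/tmp/db1", properties = Map.empty),
      ignoreIfExists = false)
    catalog.createTable(
      db = "db1",
      tableDefinition = Table(
        name = "tbl1",
        description = "",
        schema = Seq(Column("id", "int", nullable = true, comment = "")),
        partitionColumns = Seq.empty,
        sortColumns = Seq.empty,
        storage = StorageFormat("/tmp/db1/tbl1", "", "", "", Map.empty),
        numBuckets = 0,
        properties = Map.empty,
        tableType = "MANAGED_TABLE",
        createTime = 0L,
        lastAccessTime = 0L,
        viewOriginalText = None,
        viewText = None),
      ignoreIfExists = false)
    // The table is now visible through the listing API.
    assert(catalog.listTables("db1") == Seq("tbl1"))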
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
new file mode 100644
index 0000000000000..56aaa6bc6c2e9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+import org.apache.spark.sql.AnalysisException
+
+
+/**
+ * Interface for the system catalog (of columns, partitions, tables, and databases).
+ *
+ * This is only used for non-temporary items, and implementations must be thread-safe as they
+ * can be accessed from multiple threads.
+ *
+ * Implementations should throw [[AnalysisException]] when a table or database does not exist.
+ */
+abstract class Catalog {
+ import Catalog._
+
+ // --------------------------------------------------------------------------
+ // Databases
+ // --------------------------------------------------------------------------
+
+ def createDatabase(dbDefinition: Database, ignoreIfExists: Boolean): Unit
+
+ def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit
+
+ /**
+ * Alter an existing database. This operation does not support renaming.
+ */
+ def alterDatabase(db: String, dbDefinition: Database): Unit
+
+ def getDatabase(db: String): Database
+
+ def listDatabases(): Seq[String]
+
+ def listDatabases(pattern: String): Seq[String]
+
+ // --------------------------------------------------------------------------
+ // Tables
+ // --------------------------------------------------------------------------
+
+ def createTable(db: String, tableDefinition: Table, ignoreIfExists: Boolean): Unit
+
+ def dropTable(db: String, table: String, ignoreIfNotExists: Boolean): Unit
+
+ def renameTable(db: String, oldName: String, newName: String): Unit
+
+ /**
+ * Alter an existing table. This operation does not support renaming.
+ */
+ def alterTable(db: String, table: String, tableDefinition: Table): Unit
+
+ def getTable(db: String, table: String): Table
+
+ def listTables(db: String): Seq[String]
+
+ def listTables(db: String, pattern: String): Seq[String]
+
+ // --------------------------------------------------------------------------
+ // Partitions
+ // --------------------------------------------------------------------------
+
+ def createPartitions(
+ db: String,
+ table: String,
+ parts: Seq[TablePartition],
+ ignoreIfExists: Boolean): Unit
+
+ def dropPartitions(
+ db: String,
+ table: String,
+ parts: Seq[PartitionSpec],
+ ignoreIfNotExists: Boolean): Unit
+
+ /**
+ * Alter an existing table partition and optionally override its spec.
+ */
+ def alterPartition(
+ db: String,
+ table: String,
+ spec: PartitionSpec,
+ newPart: TablePartition): Unit
+
+ def getPartition(db: String, table: String, spec: PartitionSpec): TablePartition
+
+ // TODO: support listing by pattern
+ def listPartitions(db: String, table: String): Seq[TablePartition]
+
+ // --------------------------------------------------------------------------
+ // Functions
+ // --------------------------------------------------------------------------
+
+ def createFunction(db: String, funcDefinition: Function, ignoreIfExists: Boolean): Unit
+
+ def dropFunction(db: String, funcName: String): Unit
+
+ /**
+ * Alter an existing function and optionally override its name.
+ */
+ def alterFunction(db: String, funcName: String, funcDefinition: Function): Unit
+
+ def getFunction(db: String, funcName: String): Function
+
+ def listFunctions(db: String, pattern: String): Seq[String]
+
+}
+
+
+/**
+ * A function defined in the catalog.
+ *
+ * @param name name of the function
+ * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc"
+ */
+case class Function(
+ name: String,
+ className: String
+)
+
+
+/**
+ * Storage format, used to describe how a partition or a table is stored.
+ */
+case class StorageFormat(
+ locationUri: String,
+ inputFormat: String,
+ outputFormat: String,
+ serde: String,
+ serdeProperties: Map[String, String]
+)
+
+
+/**
+ * A column in a table.
+ */
+case class Column(
+ name: String,
+ dataType: String,
+ nullable: Boolean,
+ comment: String
+)
+
+
+/**
+ * A partition (Hive style) defined in the catalog.
+ *
+ * @param spec partition spec values indexed by column name
+ * @param storage storage format of the partition
+ */
+case class TablePartition(
+ spec: Catalog.PartitionSpec,
+ storage: StorageFormat
+)
+
+
+/**
+ * A table defined in the catalog.
+ *
+ * Note that Hive's metastore also tracks skewed columns. We should consider adding that in the
+ * future once we have a better understanding of how we want to handle skewed columns.
+ */
+case class Table(
+ name: String,
+ description: String,
+ schema: Seq[Column],
+ partitionColumns: Seq[Column],
+ sortColumns: Seq[Column],
+ storage: StorageFormat,
+ numBuckets: Int,
+ properties: Map[String, String],
+ tableType: String,
+ createTime: Long,
+ lastAccessTime: Long,
+ viewOriginalText: Option[String],
+ viewText: Option[String]) {
+
+ require(tableType == "EXTERNAL_TABLE" || tableType == "INDEX_TABLE" ||
+ tableType == "MANAGED_TABLE" || tableType == "VIRTUAL_VIEW")
+}
+
+
+/**
+ * A database defined in the catalog.
+ */
+case class Database(
+ name: String,
+ description: String,
+ locationUri: String,
+ properties: Map[String, String]
+)
+
+
+object Catalog {
+ /**
+ * Specifications of a table partition indexed by column name.
+ */
+ type PartitionSpec = Map[String, String]
+}
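Continuing the sketch that follows the InMemoryCatalog file above, the partition API defined here can be exercised like this (the partition spec values are illustrative only):

    val part = TablePartition(
      spec = Map("ds" -> "2016-01-01"),
      storage = StorageFormat("/tmp/db1/tbl1/ds=2016-01-01", "", "", "", Map.empty))
    // Register the partition on the table created earlier, then read it back by spec.
    catalog.createPartitions("db1", "tbl1", Seq(part), ignoreIfExists = false)
    assert(catalog.getPartition("db1", "tbl1", Map("ds" -> "2016-01-01")) == part)
    assert(catalog.listPartitions("db1", "tbl1") == Seq(part))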
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 64832dc114e67..58f6d0eb9e929 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -50,7 +50,7 @@ object ExpressionEncoder {
val cls = mirror.runtimeClass(typeTag[T].tpe)
val flat = !classOf[Product].isAssignableFrom(cls)
- val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+ val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
val fromRowExpression = ScalaReflection.constructorFor[T]
@@ -257,12 +257,10 @@ case class ExpressionEncoder[T](
}
/**
- * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
- * given schema.
+ * Validates `fromRowExpression` to make sure it can be resolved by the given schema, and produces
+ * a friendly error message explaining why resolution fails if something is wrong.
*/
- def resolve(
- schema: Seq[Attribute],
- outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
+ def validate(schema: Seq[Attribute]): Unit = {
def fail(st: StructType, maxOrdinal: Int): Unit = {
throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
"but failed as the number of fields does not line up.\n" +
@@ -270,6 +268,8 @@ case class ExpressionEncoder[T](
" - Target schema: " + this.schema.simpleString)
}
+ // If this is a tuple encoder or tupled encoder, i.e. its leaf nodes are all
+ // `BoundReference`s, make sure their ordinals are all valid.
var maxOrdinal = -1
fromRowExpression.foreach {
case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
@@ -279,6 +279,10 @@ case class ExpressionEncoder[T](
fail(StructType.fromAttributes(schema), maxOrdinal)
}
+ // If we have nested tuples, the `fromRowExpression` will contain `GetStructField` instead of
+ // `UnresolvedExtractValue`, so we need to check that their ordinals are all valid.
+ // Note that `BoundReference` contains the expected type, but here we need the actual type, so
+ // we unbind it using the given `schema` and propagate the actual type to `GetStructField`.
val unbound = fromRowExpression transform {
case b: BoundReference => schema(b.ordinal)
}
@@ -299,28 +303,24 @@ case class ExpressionEncoder[T](
fail(schema, maxOrdinal)
}
}
+ }
- val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
+ /**
+ * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
+ * given schema.
+ */
+ def resolve(
+ schema: Seq[Attribute],
+ outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
+ val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer(
+ fromRowExpression, schema)
+
+ // Make a fake plan to wrap the deserializer, so that we can go through the whole analyzer,
+ // check analysis, go through the optimizer, etc.
+ val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
- val optimizedPlan = SimplifyCasts(analyzedPlan)
-
- // In order to construct instances of inner classes (for example those declared in a REPL cell),
- // we need an instance of the outer scope. This rule substitues those outer objects into
- // expressions that are missing them by looking up the name in the SQLContexts `outerScopes`
- // registry.
- copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
- case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
- val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
- if (outer == null) {
- throw new AnalysisException(
- s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " +
- s"to the scope that this class was defined in. " + "" +
- "Try moving this class out of its parent class.")
- }
-
- n.copy(outerPointer = Some(Literal.fromObject(outer)))
- })
+ copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 89d40b3b2c141..d8f755a39c7ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -154,7 +154,7 @@ object RowEncoder {
If(
IsNull(field),
Literal.create(null, externalDataTypeFor(f.dataType)),
- constructorFor(BoundReference(i, f.dataType, f.nullable))
+ constructorFor(field)
)
}
CreateExternalRow(fields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index db17ba7c84ffc..c73b2f8f2a316 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] {
val value = ctx.freshName("value")
val ve = ExprCode("", isNull, value)
ve.code = genCode(ctx, ve)
- // Add `this` in the comment.
- ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
+ if (ve.code != "") {
+ // Add `this` in the comment.
+ ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
+ } else {
+ ve
+ }
}
}
@@ -320,7 +324,7 @@ abstract class UnaryExpression extends Expression {
/**
* Called by unary expressions to generate a code block that returns null if its parent returns
- * null, and if not not null, use `f` to generate the expression.
+ * null, and if not null, use `f` to generate the expression.
*
* As an example, the following does a boolean inversion (i.e. NOT).
* {{{
@@ -340,7 +344,7 @@ abstract class UnaryExpression extends Expression {
/**
* Called by unary expressions to generate a code block that returns null if its parent returns
- * null, and if not not null, use `f` to generate the expression.
+ * null, and if not null, use `f` to generate the expression.
*
* @param f function that accepts the non-null evaluation result name of child and returns Java
* code to compute the output.
@@ -349,20 +353,23 @@ abstract class UnaryExpression extends Expression {
ctx: CodegenContext,
ev: ExprCode,
f: String => String): String = {
- val eval = child.gen(ctx)
+ val childGen = child.gen(ctx)
+ val resultCode = f(childGen.value)
+
if (nullable) {
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
+ val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
+ s"""
+ ${childGen.code}
+ boolean ${ev.isNull} = ${childGen.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${eval.isNull}) {
- ${f(eval.value)}
- }
+ $nullSafeEval
"""
} else {
ev.isNull = "false"
- eval.code + s"""
+ s"""
+ ${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- ${f(eval.value)}
+ $resultCode
"""
}
}
@@ -440,29 +447,31 @@ abstract class BinaryExpression extends Expression {
ctx: CodegenContext,
ev: ExprCode,
f: (String, String) => String): String = {
- val eval1 = left.gen(ctx)
- val eval2 = right.gen(ctx)
- val resultCode = f(eval1.value, eval2.value)
+ val leftGen = left.gen(ctx)
+ val rightGen = right.gen(ctx)
+ val resultCode = f(leftGen.value, rightGen.value)
+
if (nullable) {
+ val nullSafeEval =
+ leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
+ rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
+ s"""
+ ${ev.isNull} = false; // resultCode could change nullability.
+ $resultCode
+ """
+ }
+ }
+
s"""
- ${eval1.code}
- boolean ${ev.isNull} = ${eval1.isNull};
+ boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${eval2.code}
- if (!${eval2.isNull}) {
- $resultCode
- } else {
- ${ev.isNull} = true;
- }
- }
+ $nullSafeEval
"""
-
} else {
ev.isNull = "false"
s"""
- ${eval1.code}
- ${eval2.code}
+ ${leftGen.code}
+ ${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode
"""
@@ -527,7 +536,7 @@ abstract class TernaryExpression extends Expression {
/**
* Default behavior of evaluation according to the default nullability of TernaryExpression.
- * If subclass of BinaryExpression override nullable, probably should also override this.
+ * If a subclass of TernaryExpression overrides nullable, it should probably also override this.
*/
override def eval(input: InternalRow): Any = {
val exprs = children
@@ -553,11 +562,11 @@ abstract class TernaryExpression extends Expression {
sys.error(s"BinaryExpressions must override either eval or nullSafeEval")
/**
- * Short hand for generating binary evaluation code.
+ * Short hand for generating ternary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
- * @param f accepts two variable names and returns Java code to compute the output.
+ * @param f accepts three variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodegenContext,
@@ -569,41 +578,46 @@ abstract class TernaryExpression extends Expression {
}
/**
- * Short hand for generating binary evaluation code.
+ * Short hand for generating ternary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
- * @param f function that accepts the 2 non-null evaluation result names of children
+ * @param f function that accepts the 3 non-null evaluation result names of children
* and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): String = {
- val evals = children.map(_.gen(ctx))
- val resultCode = f(evals(0).value, evals(1).value, evals(2).value)
+ val leftGen = children(0).gen(ctx)
+ val midGen = children(1).gen(ctx)
+ val rightGen = children(2).gen(ctx)
+ val resultCode = f(leftGen.value, midGen.value, rightGen.value)
+
if (nullable) {
+ val nullSafeEval =
+ leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) {
+ midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) {
+ rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) {
+ s"""
+ ${ev.isNull} = false; // resultCode could change nullability.
+ $resultCode
+ """
+ }
+ }
+ }
+
s"""
- ${evals(0).code}
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${evals(0).isNull}) {
- ${evals(1).code}
- if (!${evals(1).isNull}) {
- ${evals(2).code}
- if (!${evals(2).isNull}) {
- ${ev.isNull} = false; // resultCode could change nullability
- $resultCode
- }
- }
- }
+ $nullSafeEval
"""
} else {
ev.isNull = "false"
s"""
- ${evals(0).code}
- ${evals(1).code}
- ${evals(2).code}
+ ${leftGen.code}
+ ${midGen.code}
+ ${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 30f602227b17d..9d2db45144817 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -17,10 +17,8 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
@@ -44,7 +42,7 @@ import org.apache.spark.sql.types._
*
* @param child to compute central moments of.
*/
-abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable {
+abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate {
/**
* The central moment order to be computed.
@@ -52,178 +50,161 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
protected def momentOrder: Int
override def children: Seq[Expression] = Seq(child)
-
override def nullable: Boolean = true
-
override def dataType: DataType = DoubleType
+ override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+ protected val n = AttributeReference("n", DoubleType, nullable = false)()
+ protected val avg = AttributeReference("avg", DoubleType, nullable = false)()
+ protected val m2 = AttributeReference("m2", DoubleType, nullable = false)()
+ protected val m3 = AttributeReference("m3", DoubleType, nullable = false)()
+ protected val m4 = AttributeReference("m4", DoubleType, nullable = false)()
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
+ private def trimHigherOrder[T](expressions: Seq[T]) = expressions.take(momentOrder + 1)
- override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+ override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4))
- /**
- * Size of aggregation buffer.
- */
- private[this] val bufferSize = 5
+ override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0))
- override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i =>
- AttributeReference(s"M$i", DoubleType)()
+ override val updateExpressions: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val delta = child - avg
+ val deltaN = delta / newN
+ val newAvg = avg + deltaN
+ val newM2 = m2 + delta * (delta - deltaN)
+
+ val delta2 = delta * delta
+ val deltaN2 = deltaN * deltaN
+ val newM3 = if (momentOrder >= 3) {
+ m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2)
+ } else {
+ Literal(0.0)
+ }
+ val newM4 = if (momentOrder >= 4) {
+ m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 +
+ delta * (delta * delta2 - deltaN * deltaN2)
+ } else {
+ Literal(0.0)
+ }
+
+ trimHigherOrder(Seq(
+ If(IsNull(child), n, newN),
+ If(IsNull(child), avg, newAvg),
+ If(IsNull(child), m2, newM2),
+ If(IsNull(child), m3, newM3),
+ If(IsNull(child), m4, newM4)
+ ))
}
- // Note: although this simply copies aggBufferAttributes, this common code can not be placed
- // in the superclass because that will lead to initialization ordering issues.
- override val inputAggBufferAttributes: Seq[AttributeReference] =
- aggBufferAttributes.map(_.newInstance())
-
- // buffer offsets
- private[this] val nOffset = mutableAggBufferOffset
- private[this] val meanOffset = mutableAggBufferOffset + 1
- private[this] val secondMomentOffset = mutableAggBufferOffset + 2
- private[this] val thirdMomentOffset = mutableAggBufferOffset + 3
- private[this] val fourthMomentOffset = mutableAggBufferOffset + 4
-
- // frequently used values for online updates
- private[this] var delta = 0.0
- private[this] var deltaN = 0.0
- private[this] var delta2 = 0.0
- private[this] var deltaN2 = 0.0
- private[this] var n = 0.0
- private[this] var mean = 0.0
- private[this] var m2 = 0.0
- private[this] var m3 = 0.0
- private[this] var m4 = 0.0
+ override val mergeExpressions: Seq[Expression] = {
- /**
- * Initialize all moments to zero.
- */
- override def initialize(buffer: MutableRow): Unit = {
- for (aggIndex <- 0 until bufferSize) {
- buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
+ val n1 = n.left
+ val n2 = n.right
+ val newN = n1 + n2
+ val delta = avg.right - avg.left
+ val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN)
+ val newAvg = avg.left + deltaN * n2
+
+ // higher order moments computed according to:
+ // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
+ val newM2 = m2.left + m2.right + delta * deltaN * n1 * n2
+ // `m3.right` is not available if momentOrder < 3
+ val newM3 = if (momentOrder >= 3) {
+ m3.left + m3.right + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) +
+ Literal(3.0) * deltaN * (n1 * m2.right - n2 * m2.left)
+ } else {
+ Literal(0.0)
}
+ // `m4.right` is not available if momentOrder < 4
+ val newM4 = if (momentOrder >= 4) {
+ m4.left + m4.right +
+ deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * n2) +
+ Literal(6.0) * deltaN * deltaN * (n1 * n1 * m2.right + n2 * n2 * m2.left) +
+ Literal(4.0) * deltaN * (n1 * m3.right - n2 * m3.left)
+ } else {
+ Literal(0.0)
+ }
+
+ trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4))
}
+}
- /**
- * Update the central moments buffer.
- */
- override def update(buffer: MutableRow, input: InternalRow): Unit = {
- val v = Cast(child, DoubleType).eval(input)
- if (v != null) {
- val updateValue = v match {
- case d: Double => d
- }
-
- n = buffer.getDouble(nOffset)
- mean = buffer.getDouble(meanOffset)
-
- n += 1.0
- buffer.setDouble(nOffset, n)
- delta = updateValue - mean
- deltaN = delta / n
- mean += deltaN
- buffer.setDouble(meanOffset, mean)
-
- if (momentOrder >= 2) {
- m2 = buffer.getDouble(secondMomentOffset)
- m2 += delta * (delta - deltaN)
- buffer.setDouble(secondMomentOffset, m2)
- }
-
- if (momentOrder >= 3) {
- delta2 = delta * delta
- deltaN2 = deltaN * deltaN
- m3 = buffer.getDouble(thirdMomentOffset)
- m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
- buffer.setDouble(thirdMomentOffset, m3)
- }
-
- if (momentOrder >= 4) {
- m4 = buffer.getDouble(fourthMomentOffset)
- m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
- delta * (delta * delta2 - deltaN * deltaN2)
- buffer.setDouble(fourthMomentOffset, m4)
- }
- }
+// Compute the population standard deviation of a column
+case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
+
+ override protected def momentOrder = 2
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ Sqrt(m2 / n))
}
- /**
- * Merge two central moment buffers.
- */
- override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val n1 = buffer1.getDouble(nOffset)
- val n2 = buffer2.getDouble(inputAggBufferOffset)
- val mean1 = buffer1.getDouble(meanOffset)
- val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)
+ override def prettyName: String = "stddev_pop"
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
+
+ override protected def momentOrder = 2
- var secondMoment1 = 0.0
- var secondMoment2 = 0.0
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(n === Literal(1.0), Literal(Double.NaN),
+ Sqrt(m2 / (n - Literal(1.0)))))
+ }
- var thirdMoment1 = 0.0
- var thirdMoment2 = 0.0
+ override def prettyName: String = "stddev_samp"
+}
- var fourthMoment1 = 0.0
- var fourthMoment2 = 0.0
+// Compute the population variance of a column
+case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
- n = n1 + n2
- buffer1.setDouble(nOffset, n)
- delta = mean2 - mean1
- deltaN = if (n == 0.0) 0.0 else delta / n
- mean = mean1 + deltaN * n2
- buffer1.setDouble(mutableAggBufferOffset + 1, mean)
+ override protected def momentOrder = 2
- // higher order moments computed according to:
- // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
- if (momentOrder >= 2) {
- secondMoment1 = buffer1.getDouble(secondMomentOffset)
- secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
- m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2
- buffer1.setDouble(secondMomentOffset, m2)
- }
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ m2 / n)
+ }
- if (momentOrder >= 3) {
- thirdMoment1 = buffer1.getDouble(thirdMomentOffset)
- thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
- m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 *
- (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1)
- buffer1.setDouble(thirdMomentOffset, m3)
- }
+ override def prettyName: String = "var_pop"
+}
- if (momentOrder >= 4) {
- fourthMoment1 = buffer1.getDouble(fourthMomentOffset)
- fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
- m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 *
- n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 *
- (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) +
- 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1)
- buffer1.setDouble(fourthMomentOffset, m4)
- }
+// Compute the sample variance of a column
+case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
+
+ override protected def momentOrder = 2
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(n === Literal(1.0), Literal(Double.NaN),
+ m2 / (n - Literal(1.0))))
}
- /**
- * Compute aggregate statistic from sufficient moments.
- * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized)
- * needed to compute the aggregate stat.
- */
- def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any
-
- override final def eval(buffer: InternalRow): Any = {
- val n = buffer.getDouble(nOffset)
- val mean = buffer.getDouble(meanOffset)
- val moments = Array.ofDim[Double](momentOrder + 1)
- moments(0) = 1.0
- moments(1) = 0.0
- if (momentOrder >= 2) {
- moments(2) = buffer.getDouble(secondMomentOffset)
- }
- if (momentOrder >= 3) {
- moments(3) = buffer.getDouble(thirdMomentOffset)
- }
- if (momentOrder >= 4) {
- moments(4) = buffer.getDouble(fourthMomentOffset)
- }
+ override def prettyName: String = "var_samp"
+}
+
+case class Skewness(child: Expression) extends CentralMomentAgg(child) {
+
+ override def prettyName: String = "skewness"
+
+ override protected def momentOrder = 3
- getStatistic(n, mean, moments)
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(m2 === Literal(0.0), Literal(Double.NaN),
+ Sqrt(n) * m3 / Sqrt(m2 * m2 * m2)))
}
}
+
+case class Kurtosis(child: Expression) extends CentralMomentAgg(child) {
+
+ override protected def momentOrder = 4
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(m2 === Literal(0.0), Literal(Double.NaN),
+ n * m4 / (m2 * m2) - Literal(3.0)))
+ }
+
+ override def prettyName: String = "kurtosis"
+}
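The declarative update and merge expressions above follow the standard streaming formulas for central moments. As a sanity check, the second-moment merge can be written out in plain Scala (names mirror the diff; this is illustrative and not part of the patch):

    // Merge the (count, mean, m2) summaries of two partitions.
    def mergeM2(n1: Double, avg1: Double, m21: Double,
                n2: Double, avg2: Double, m22: Double): (Double, Double, Double) = {
      val newN = n1 + n2
      val delta = avg2 - avg1
      val deltaN = if (newN == 0.0) 0.0 else delta / newN
      val newAvg = avg1 + deltaN * n2
      val newM2 = m21 + m22 + delta * deltaN * n1 * n2
      (newN, newAvg, newM2)
    }
    // Merging two single-element partitions {1.0} and {3.0} gives n = 2, avg = 2, m2 = 2,
    // so var_pop = m2 / n = 1.0 and stddev_pop = 1.0, matching evaluateExpression above.
    assert(mergeM2(1.0, 1.0, 0.0, 1.0, 3.0, 0.0) == ((2.0, 2.0, 2.0)))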
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index d25f3335ffd93..e6b8214ef25e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -29,165 +28,70 @@ import org.apache.spark.sql.types._
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*/
-case class Corr(
- left: Expression,
- right: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends ImperativeAggregate {
-
- def this(left: Expression, right: Expression) =
- this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def children: Seq[Expression] = Seq(left, right)
+case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate {
+ override def children: Seq[Expression] = Seq(x, y)
override def nullable: Boolean = true
-
override def dataType: DataType = DoubleType
-
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
- override def checkInputDataTypes(): TypeCheckResult = {
- if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(
- s"corr requires that both arguments are double type, " +
- s"not (${left.dataType}, ${right.dataType}).")
- }
+ protected val n = AttributeReference("n", DoubleType, nullable = false)()
+ protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)()
+ protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)()
+ protected val ck = AttributeReference("ck", DoubleType, nullable = false)()
+ protected val xMk = AttributeReference("xMk", DoubleType, nullable = false)()
+ protected val yMk = AttributeReference("yMk", DoubleType, nullable = false)()
+
+ override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck, xMk, yMk)
+
+ override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0))
+
+ override val updateExpressions: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val dx = x - xAvg
+ val dxN = dx / newN
+ val dy = y - yAvg
+ val dyN = dy / newN
+ val newXAvg = xAvg + dxN
+ val newYAvg = yAvg + dyN
+ val newCk = ck + dx * (y - newYAvg)
+ val newXMk = xMk + dx * (x - newXAvg)
+ val newYMk = yMk + dy * (y - newYAvg)
+
+ val isNull = IsNull(x) || IsNull(y)
+ Seq(
+ If(isNull, n, newN),
+ If(isNull, xAvg, newXAvg),
+ If(isNull, yAvg, newYAvg),
+ If(isNull, ck, newCk),
+ If(isNull, xMk, newXMk),
+ If(isNull, yMk, newYMk)
+ )
}
- override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
- override def inputAggBufferAttributes: Seq[AttributeReference] = {
- aggBufferAttributes.map(_.newInstance())
+ override val mergeExpressions: Seq[Expression] = {
+
+ val n1 = n.left
+ val n2 = n.right
+ val newN = n1 + n2
+ val dx = xAvg.right - xAvg.left
+ val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+ val dy = yAvg.right - yAvg.left
+ val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+ val newXAvg = xAvg.left + dxN * n2
+ val newYAvg = yAvg.left + dyN * n2
+ val newCk = ck.left + ck.right + dx * dyN * n1 * n2
+ val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2
+ val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2
+
+ Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk)
}
- override val aggBufferAttributes: Seq[AttributeReference] = Seq(
- AttributeReference("xAvg", DoubleType)(),
- AttributeReference("yAvg", DoubleType)(),
- AttributeReference("Ck", DoubleType)(),
- AttributeReference("MkX", DoubleType)(),
- AttributeReference("MkY", DoubleType)(),
- AttributeReference("count", LongType)())
-
- // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
- private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
- private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
- private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3
- private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4
- private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5
-
- // Local cache of inputAggBufferOffset(s) that will be used in update and merge
- private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
- private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
- private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3
- private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4
- private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def initialize(buffer: MutableRow): Unit = {
- buffer.setDouble(mutableAggBufferOffset, 0.0)
- buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0)
- buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0)
- buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0)
- buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0)
- buffer.setLong(mutableAggBufferOffsetPlus5, 0L)
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(n === Literal(1.0), Literal(Double.NaN),
+ ck / Sqrt(xMk * yMk)))
}
- override def update(buffer: MutableRow, input: InternalRow): Unit = {
- val leftEval = left.eval(input)
- val rightEval = right.eval(input)
-
- if (leftEval != null && rightEval != null) {
- val x = leftEval.asInstanceOf[Double]
- val y = rightEval.asInstanceOf[Double]
-
- var xAvg = buffer.getDouble(mutableAggBufferOffset)
- var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
- var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
- var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
- var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
- var count = buffer.getLong(mutableAggBufferOffsetPlus5)
-
- val deltaX = x - xAvg
- val deltaY = y - yAvg
- count += 1
- xAvg += deltaX / count
- yAvg += deltaY / count
- Ck += deltaX * (y - yAvg)
- MkX += deltaX * (x - xAvg)
- MkY += deltaY * (y - yAvg)
-
- buffer.setDouble(mutableAggBufferOffset, xAvg)
- buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
- buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
- buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
- buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
- buffer.setLong(mutableAggBufferOffsetPlus5, count)
- }
- }
-
- // Merge counters from other partitions. Formula can be found at:
- // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
- override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val count2 = buffer2.getLong(inputAggBufferOffsetPlus5)
-
- // We only go to merge two buffers if there is at least one record aggregated in buffer2.
- // We don't need to check count in buffer1 because if count2 is more than zero, totalCount
- // is more than zero too, then we won't get a divide by zero exception.
- if (count2 > 0) {
- var xAvg = buffer1.getDouble(mutableAggBufferOffset)
- var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
- var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
- var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3)
- var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4)
- var count = buffer1.getLong(mutableAggBufferOffsetPlus5)
-
- val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
- val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
- val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)
- val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3)
- val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4)
-
- val totalCount = count + count2
- val deltaX = xAvg - xAvg2
- val deltaY = yAvg - yAvg2
- Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
- xAvg = (xAvg * count + xAvg2 * count2) / totalCount
- yAvg = (yAvg * count + yAvg2 * count2) / totalCount
- MkX += MkX2 + deltaX * deltaX * count / totalCount * count2
- MkY += MkY2 + deltaY * deltaY * count / totalCount * count2
- count = totalCount
-
- buffer1.setDouble(mutableAggBufferOffset, xAvg)
- buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
- buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
- buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX)
- buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY)
- buffer1.setLong(mutableAggBufferOffsetPlus5, count)
- }
- }
-
- override def eval(buffer: InternalRow): Any = {
- val count = buffer.getLong(mutableAggBufferOffsetPlus5)
- if (count > 0) {
- val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
- val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
- val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
- val corr = Ck / math.sqrt(MkX * MkY)
- if (corr.isNaN) {
- null
- } else {
- corr
- }
- } else {
- null
- }
- }
+ override def prettyName: String = "corr"
}
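To see the update expressions above at work, the same bookkeeping can be walked through in plain Scala on a toy dataset; for perfectly linearly related pairs the resulting correlation is 1.0 (illustrative only, names mirror the diff):

    var (n, xAvg, yAvg, ck, xMk, yMk) = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    for ((x, y) <- Seq((1.0, 2.0), (2.0, 4.0), (3.0, 6.0))) {
      val newN = n + 1.0
      val dx = x - xAvg          // delta against the old running mean
      val dy = y - yAvg
      xAvg += dx / newN          // update means first, as in updateExpressions
      yAvg += dy / newN
      ck += dx * (y - yAvg)      // co-moment uses old dx and the new y mean
      xMk += dx * (x - xAvg)
      yMk += dy * (y - yAvg)
      n = newN
    }
    // corr = ck / sqrt(xMk * yMk), exactly the evaluateExpression above.
    assert(math.abs(ck / math.sqrt(xMk * yMk) - 1.0) < 1e-9)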
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index f53b01be2a0d5..c175a8c4c77b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -17,182 +17,79 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
* Compute the covariance between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL.
- *
*/
-abstract class Covariance(left: Expression, right: Expression) extends ImperativeAggregate
- with Serializable {
- override def children: Seq[Expression] = Seq(left, right)
+abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate {
+ override def children: Seq[Expression] = Seq(x, y)
override def nullable: Boolean = true
-
override def dataType: DataType = DoubleType
-
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
- override def checkInputDataTypes(): TypeCheckResult = {
- if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- TypeCheckResult.TypeCheckFailure(
- s"covariance requires that both arguments are double type, " +
- s"not (${left.dataType}, ${right.dataType}).")
- }
- }
-
- override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
- override def inputAggBufferAttributes: Seq[AttributeReference] = {
- aggBufferAttributes.map(_.newInstance())
- }
-
- override val aggBufferAttributes: Seq[AttributeReference] = Seq(
- AttributeReference("xAvg", DoubleType)(),
- AttributeReference("yAvg", DoubleType)(),
- AttributeReference("Ck", DoubleType)(),
- AttributeReference("count", LongType)())
-
- // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
- val xAvgOffset = mutableAggBufferOffset
- val yAvgOffset = mutableAggBufferOffset + 1
- val CkOffset = mutableAggBufferOffset + 2
- val countOffset = mutableAggBufferOffset + 3
-
- // Local cache of inputAggBufferOffset(s) that will be used in update and merge
- val inputXAvgOffset = inputAggBufferOffset
- val inputYAvgOffset = inputAggBufferOffset + 1
- val inputCkOffset = inputAggBufferOffset + 2
- val inputCountOffset = inputAggBufferOffset + 3
-
- override def initialize(buffer: MutableRow): Unit = {
- buffer.setDouble(xAvgOffset, 0.0)
- buffer.setDouble(yAvgOffset, 0.0)
- buffer.setDouble(CkOffset, 0.0)
- buffer.setLong(countOffset, 0L)
- }
-
- override def update(buffer: MutableRow, input: InternalRow): Unit = {
- val leftEval = left.eval(input)
- val rightEval = right.eval(input)
-
- if (leftEval != null && rightEval != null) {
- val x = leftEval.asInstanceOf[Double]
- val y = rightEval.asInstanceOf[Double]
-
- var xAvg = buffer.getDouble(xAvgOffset)
- var yAvg = buffer.getDouble(yAvgOffset)
- var Ck = buffer.getDouble(CkOffset)
- var count = buffer.getLong(countOffset)
-
- val deltaX = x - xAvg
- val deltaY = y - yAvg
- count += 1
- xAvg += deltaX / count
- yAvg += deltaY / count
- Ck += deltaX * (y - yAvg)
-
- buffer.setDouble(xAvgOffset, xAvg)
- buffer.setDouble(yAvgOffset, yAvg)
- buffer.setDouble(CkOffset, Ck)
- buffer.setLong(countOffset, count)
- }
+ protected val n = AttributeReference("n", DoubleType, nullable = false)()
+ protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)()
+ protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)()
+ protected val ck = AttributeReference("ck", DoubleType, nullable = false)()
+
+ override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck)
+
+ override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0))
+
+ override lazy val updateExpressions: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val dx = x - xAvg
+ val dy = y - yAvg
+ val dyN = dy / newN
+ val newXAvg = xAvg + dx / newN
+ val newYAvg = yAvg + dyN
+ val newCk = ck + dx * (y - newYAvg)
+
+ val isNull = IsNull(x) || IsNull(y)
+ Seq(
+ If(isNull, n, newN),
+ If(isNull, xAvg, newXAvg),
+ If(isNull, yAvg, newYAvg),
+ If(isNull, ck, newCk)
+ )
}
- // Merge counters from other partitions. Formula can be found at:
- // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
- override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
- val count2 = buffer2.getLong(inputCountOffset)
-
- // We only go to merge two buffers if there is at least one record aggregated in buffer2.
- // We don't need to check count in buffer1 because if count2 is more than zero, totalCount
- // is more than zero too, then we won't get a divide by zero exception.
- if (count2 > 0) {
- var xAvg = buffer1.getDouble(xAvgOffset)
- var yAvg = buffer1.getDouble(yAvgOffset)
- var Ck = buffer1.getDouble(CkOffset)
- var count = buffer1.getLong(countOffset)
+ override val mergeExpressions: Seq[Expression] = {
- val xAvg2 = buffer2.getDouble(inputXAvgOffset)
- val yAvg2 = buffer2.getDouble(inputYAvgOffset)
- val Ck2 = buffer2.getDouble(inputCkOffset)
+ val n1 = n.left
+ val n2 = n.right
+ val newN = n1 + n2
+ val dx = xAvg.right - xAvg.left
+ val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+ val dy = yAvg.right - yAvg.left
+ val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+ val newXAvg = xAvg.left + dxN * n2
+ val newYAvg = yAvg.left + dyN * n2
+ val newCk = ck.left + ck.right + dx * dyN * n1 * n2
- val totalCount = count + count2
- val deltaX = xAvg - xAvg2
- val deltaY = yAvg - yAvg2
- Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
- xAvg = (xAvg * count + xAvg2 * count2) / totalCount
- yAvg = (yAvg * count + yAvg2 * count2) / totalCount
- count = totalCount
-
- buffer1.setDouble(xAvgOffset, xAvg)
- buffer1.setDouble(yAvgOffset, yAvg)
- buffer1.setDouble(CkOffset, Ck)
- buffer1.setLong(countOffset, count)
- }
+ Seq(newN, newXAvg, newYAvg, newCk)
}
}
-case class CovSample(
- left: Expression,
- right: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends Covariance(left, right) {
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def eval(buffer: InternalRow): Any = {
- val count = buffer.getLong(countOffset)
- if (count > 1) {
- val Ck = buffer.getDouble(CkOffset)
- val cov = Ck / (count - 1)
- if (cov.isNaN) {
- null
- } else {
- cov
- }
- } else {
- null
- }
+case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) {
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ ck / n)
}
+ override def prettyName: String = "covar_pop"
}
-case class CovPopulation(
- left: Expression,
- right: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends Covariance(left, right) {
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
- override def eval(buffer: InternalRow): Any = {
- val count = buffer.getLong(countOffset)
- if (count > 0) {
- val Ck = buffer.getDouble(CkOffset)
- if (Ck.isNaN) {
- null
- } else {
- Ck / count
- }
- } else {
- null
- }
+case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) {
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType),
+ If(n === Literal(1.0), Literal(Double.NaN),
+ ck / (n - Literal(1.0))))
}
+ override def prettyName: String = "covar_samp"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
deleted file mode 100644
index c2bf2cb94116c..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class Kurtosis(child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "kurtosis"
-
- override protected val momentOrder = 4
-
- // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
- val m2 = moments(2)
- val m4 = moments(4)
-
- if (n == 0.0) {
- null
- } else if (m2 == 0.0) {
- Double.NaN
- } else {
- n * m4 / (m2 * m2) - 3.0
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
deleted file mode 100644
index 9411bcea2539a..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class Skewness(child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "skewness"
-
- override protected val momentOrder = 3
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
- val m2 = moments(2)
- val m3 = moments(3)
-
- if (n == 0.0) {
- null
- } else if (m2 == 0.0) {
- Double.NaN
- } else {
- math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
deleted file mode 100644
index eec79a9033e36..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class StddevSamp(child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "stddev_samp"
-
- override protected val momentOrder = 2
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
-
- if (n == 0.0) {
- null
- } else if (n == 1.0) {
- Double.NaN
- } else {
- math.sqrt(moments(2) / (n - 1.0))
- }
- }
-}
-
-case class StddevPop(
- child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "stddev_pop"
-
- override protected val momentOrder = 2
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
-
- if (n == 0.0) {
- null
- } else {
- math.sqrt(moments(2) / n)
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
deleted file mode 100644
index cf3a740305391..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class VarianceSamp(child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "var_samp"
-
- override protected val momentOrder = 2
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
-
- if (n == 0.0) {
- null
- } else if (n == 1.0) {
- Double.NaN
- } else {
- moments(2) / (n - 1.0)
- }
- }
-}
-
-case class VariancePop(
- child: Expression,
- mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0)
- extends CentralMomentAgg(child) {
-
- def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
-
- override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
- copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
- override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
- copy(inputAggBufferOffset = newInputAggBufferOffset)
-
- override def prettyName: String = "var_pop"
-
- override protected val momentOrder = 2
-
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
- require(moments.length == momentOrder + 1,
- s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")
-
- if (n == 0.0) {
- null
- } else {
- moments(2) / n
- }
- }
-}
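
The four files deleted above each derived their statistic from the row count and central moments; that functionality is consolidated into the declarative central-moment aggregates elsewhere in this change. For reference, a plain-Scala restatement of the formulas the deleted `getStatistic` methods computed; this is a sketch only, not the replacement implementation:

```scala
object CentralMomentStats {
  // n is the row count; m2, m3, m4 are the 2nd/3rd/4th central moments.
  def variancePop(n: Double, m2: Double): Option[Double] =
    if (n == 0.0) None else Some(m2 / n)

  def varianceSamp(n: Double, m2: Double): Option[Double] =
    if (n == 0.0) None
    else if (n == 1.0) Some(Double.NaN)
    else Some(m2 / (n - 1.0))

  def stddevPop(n: Double, m2: Double): Option[Double] =
    variancePop(n, m2).map(math.sqrt)

  def stddevSamp(n: Double, m2: Double): Option[Double] =
    varianceSamp(n, m2).map(math.sqrt)

  def skewness(n: Double, m2: Double, m3: Double): Option[Double] =
    if (n == 0.0) None
    else if (m2 == 0.0) Some(Double.NaN)
    else Some(math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2))

  // Excess kurtosis (the R/SciPy default), as in the deleted Kurtosis class.
  def kurtosis(n: Double, m2: Double, m4: Double): Option[Double] =
    if (n == 0.0) None
    else if (m2 == 0.0) Some(Double.NaN)
    else Some(n * m4 / (m2 * m2) - 3.0)
}
```
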
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2747c315ad374..63e19564dd861 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -55,6 +55,20 @@ class CodegenContext {
*/
val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()
+ /**
+   * Add an object to `references`, and create a class member to access it.
+   *
+   * Returns the name of the class member.
+ */
+ def addReferenceObj(name: String, obj: Any, className: String = null): String = {
+ val term = freshName(name)
+ val idx = references.length
+ references += obj
+ val clsName = Option(className).getOrElse(obj.getClass.getName)
+ addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];")
+ term
+ }
+
/**
* Holding a list of generated columns as input of current operator, will be used by
* BoundReference to generate code.
@@ -142,16 +156,34 @@ class CodegenContext {
/** The variable name of the input row in generated code. */
final var INPUT_ROW = "i"
- private val curId = new java.util.concurrent.atomic.AtomicInteger()
+ /**
+   * The map from a variable name to its next ID.
+ */
+ private val freshNameIds = new mutable.HashMap[String, Int]
+ freshNameIds += INPUT_ROW -> 1
/**
- * Returns a term name that is unique within this instance of a `CodeGenerator`.
- *
- * (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
- * function.)
+   * A prefix used to generate fresh names.
+ */
+ var freshNamePrefix = ""
+
+ /**
+ * Returns a term name that is unique within this instance of a `CodegenContext`.
*/
- def freshName(prefix: String): String = {
- s"$prefix${curId.getAndIncrement}"
+ def freshName(name: String): String = synchronized {
+ val fullName = if (freshNamePrefix == "") {
+ name
+ } else {
+ s"${freshNamePrefix}_$name"
+ }
+ if (freshNameIds.contains(fullName)) {
+ val id = freshNameIds(fullName)
+ freshNameIds(fullName) = id + 1
+ s"$fullName$id"
+ } else {
+ freshNameIds += fullName -> 1
+ fullName
+ }
}
/**
@@ -189,6 +221,39 @@ class CodegenContext {
}
}
+ /**
+   * Update a column in a MutableRow from an ExprCode.
+ */
+ def updateColumn(
+ row: String,
+ dataType: DataType,
+ ordinal: Int,
+ ev: ExprCode,
+ nullable: Boolean): String = {
+ if (nullable) {
+ // Can't call setNullAt on DecimalType, because we need to keep the offset
+ if (dataType.isInstanceOf[DecimalType]) {
+ s"""
+ if (!${ev.isNull}) {
+ ${setColumn(row, dataType, ordinal, ev.value)};
+ } else {
+ ${setColumn(row, dataType, ordinal, "null")};
+ }
+ """
+ } else {
+ s"""
+ if (!${ev.isNull}) {
+ ${setColumn(row, dataType, ordinal, ev.value)};
+ } else {
+ $row.setNullAt($ordinal);
+ }
+ """
+ }
+ } else {
+ s"""${setColumn(row, dataType, ordinal, ev.value)};"""
+ }
+ }
+
/**
* Returns the name used in accessor and setter for a Java primitive type.
*/
@@ -346,17 +411,37 @@ class CodegenContext {
}
/**
- * Generates code for greater of two expressions.
- *
- * @param dataType data type of the expressions
- * @param c1 name of the variable of expression 1's output
- * @param c2 name of the variable of expression 2's output
- */
+ * Generates code for greater of two expressions.
+ *
+ * @param dataType data type of the expressions
+ * @param c1 name of the variable of expression 1's output
+ * @param c2 name of the variable of expression 2's output
+ */
def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match {
case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2"
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
}
+ /**
+   * Generates code to do null-safe execution, i.e. only execute the code when the input is not
+   * null, by adding a null check if necessary.
+   *
+   * @param nullable used to decide whether we should add a null check or not.
+ * @param isNull the code to check if the input is null.
+ * @param execute the code that should only be executed when the input is not null.
+ */
+ def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = {
+ if (nullable) {
+ s"""
+ if (!$isNull) {
+ $execute
+ }
+ """
+ } else {
+ "\n" + execute
+ }
+ }
+
/**
* List of java data types that have special accessors and setters in [[InternalRow]].
*/
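
The new `freshName` keeps one counter per requested name, optionally namespaced by `freshNamePrefix`, instead of appending a single global AtomicInteger value, so generated variable names stay short and readable. A standalone sketch of that allocation scheme, using a hypothetical `NameAllocator` class in place of `CodegenContext`:

```scala
import scala.collection.mutable

// Illustrative sketch of the per-name counter scheme used by freshName.
class NameAllocator(var prefix: String = "") {
  // "i" is pre-registered, mirroring INPUT_ROW in CodegenContext.
  private val ids = mutable.HashMap[String, Int]("i" -> 1)

  def freshName(name: String): String = synchronized {
    val fullName = if (prefix.isEmpty) name else s"${prefix}_$name"
    ids.get(fullName) match {
      case Some(id) =>
        ids(fullName) = id + 1
        s"$fullName$id"       // e.g. "value1", "value2", ...
      case None =>
        ids(fullName) = 1
        fullName              // the first request keeps the plain name
    }
  }
}
```

A first call to `freshName("value")` returns "value"; subsequent calls return "value1", "value2", and so on.
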
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d9fe76133c6ef..5b4dc8df8622b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -88,31 +88,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
val updates = validExpr.zip(index).map {
case (e, i) =>
- if (e.nullable) {
- if (e.dataType.isInstanceOf[DecimalType]) {
- // Can't call setNullAt on DecimalType, because we need to keep the offset
- s"""
- if (this.isNull_$i) {
- ${ctx.setColumn("mutableRow", e.dataType, i, null)};
- } else {
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- }
- """
- } else {
- s"""
- if (this.isNull_$i) {
- mutableRow.setNullAt($i);
- } else {
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- }
- """
- }
- } else {
- s"""
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- """
- }
-
+ val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i")
+ ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 5256baaf432a2..6b24fae9f3f1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -173,22 +173,26 @@ case class GetArrayStructFields(
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, eval => {
+ val n = ctx.freshName("n")
+ val values = ctx.freshName("values")
+ val j = ctx.freshName("j")
+ val row = ctx.freshName("row")
s"""
- final int n = $eval.numElements();
- final Object[] values = new Object[n];
- for (int j = 0; j < n; j++) {
- if ($eval.isNullAt(j)) {
- values[j] = null;
+ final int $n = $eval.numElements();
+ final Object[] $values = new Object[$n];
+ for (int $j = 0; $j < $n; $j++) {
+ if ($eval.isNullAt($j)) {
+ $values[$j] = null;
} else {
- final InternalRow row = $eval.getStruct(j, $numFields);
- if (row.isNullAt($ordinal)) {
- values[j] = null;
+ final InternalRow $row = $eval.getStruct($j, $numFields);
+ if ($row.isNullAt($ordinal)) {
+ $values[$j] = null;
} else {
- values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
+ $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
}
}
}
- ${ev.value} = new $arrayClass(values);
+ ${ev.value} = new $arrayClass($values);
"""
})
}
@@ -218,7 +222,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
- if (index >= baseValue.numElements() || index < 0) {
+ if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) {
null
} else {
baseValue.get(index, dataType)
@@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+ val index = ctx.freshName("index")
s"""
- final int index = (int) $eval2;
- if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) {
+ final int $index = (int) $eval2;
+ if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) {
${ev.isNull} = true;
} else {
- ${ev.value} = ${ctx.getValue(eval1, dataType, "index")};
+ ${ev.value} = ${ctx.getValue(eval1, dataType, index)};
}
"""
})
@@ -267,6 +272,7 @@ case class GetMapValue(child: Expression, key: Expression)
val map = value.asInstanceOf[MapData]
val length = map.numElements()
val keys = map.keyArray()
+ val values = map.valueArray()
var i = 0
var found = false
@@ -278,10 +284,10 @@ case class GetMapValue(child: Expression, key: Expression)
}
}
- if (!found) {
+ if (!found || values.isNullAt(i)) {
null
} else {
- map.valueArray().get(i, dataType)
+ values.get(i, dataType)
}
}
@@ -291,10 +297,12 @@ case class GetMapValue(child: Expression, key: Expression)
val keys = ctx.freshName("keys")
val found = ctx.freshName("found")
val key = ctx.freshName("key")
+ val values = ctx.freshName("values")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
final int $length = $eval1.numElements();
final ArrayData $keys = $eval1.keyArray();
+ final ArrayData $values = $eval1.valueArray();
int $index = 0;
boolean $found = false;
@@ -307,10 +315,10 @@ case class GetMapValue(child: Expression, key: Expression)
}
}
- if ($found) {
- ${ev.value} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)};
- } else {
+ if (!$found || $values.isNullAt($index)) {
${ev.isNull} = true;
+ } else {
+ ${ev.value} = ${ctx.getValue(values, dataType, index)};
}
"""
})
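
With these changes, the interpreted (`nullSafeEval`) and generated paths of GetArrayItem and GetMapValue agree: an out-of-range index, a missing key, and a stored null element all produce null. A small sketch of that intended semantics over plain Scala collections, using `Option` to stand in for SQL NULL (illustrative only):

```scala
object NullAwareExtractors {
  // Array element lookup: a bad index or a stored null both yield None.
  def getArrayItem[T](values: IndexedSeq[Option[T]], index: Int): Option[T] =
    if (index < 0 || index >= values.length) None else values(index)

  // Map value lookup by linear key scan, as in GetMapValue: a missing key or
  // a stored null value both yield None.
  def getMapValue[K, V](keys: IndexedSeq[K], values: IndexedSeq[Option[V]], key: K): Option[V] = {
    val i = keys.indexOf(key)
    if (i < 0) None else values(i)
  }
}
```
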
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 493e0aae01af7..f4ccadd9c563e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -325,36 +325,50 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
ev.isNull = "false"
- val childrenHash = children.zipWithIndex.map {
- case (child, dt) =>
- val childGen = child.gen(ctx)
- val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
- s"""
- ${childGen.code}
- if (!${childGen.isNull}) {
- ${childHash.code}
- ${ev.value} = ${childHash.value};
- }
- """
+ val childrenHash = children.map { child =>
+ val childGen = child.gen(ctx)
+ childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
+ computeHash(childGen.value, child.dataType, ev.value, ctx)
+ }
}.mkString("\n")
+
s"""
int ${ev.value} = $seed;
$childrenHash
"""
}
+ private def nullSafeElementHash(
+ input: String,
+ index: String,
+ nullable: Boolean,
+ elementType: DataType,
+ result: String,
+ ctx: CodegenContext): String = {
+ val element = ctx.freshName("element")
+
+ ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
+ s"""
+ final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
+ ${computeHash(element, elementType, result, ctx)}
+ """
+ }
+ }
+
private def computeHash(
input: String,
dataType: DataType,
- seed: String,
- ctx: CodegenContext): ExprCode = {
+ result: String,
+ ctx: CodegenContext): String = {
val hasher = classOf[Murmur3_x86_32].getName
- def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)")
- def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)")
- def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v)
+
+ def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);"
+ def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);"
+ def hashBytes(b: String): String =
+ s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);"
dataType match {
- case NullType => inlineValue(seed)
+ case NullType => ""
case BooleanType => hashInt(s"$input ? 1 : 0")
case ByteType | ShortType | IntegerType | DateType => hashInt(input)
case LongType | TimestampType => hashLong(input)
@@ -365,91 +379,66 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
hashLong(s"$input.toUnscaledLong()")
} else {
val bytes = ctx.freshName("bytes")
- val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();"
- val offset = "Platform.BYTE_ARRAY_OFFSET"
- val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)"
- ExprCode(code, "false", result)
+ s"""
+ final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
+ ${hashBytes(bytes)}
+ """
}
case CalendarIntervalType =>
- val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)"
- val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)"
- inlineValue(monthsHash)
- case BinaryType =>
- val offset = "Platform.BYTE_ARRAY_OFFSET"
- inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)")
+ val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)"
+ s"$result = $hasher.hashInt($input.months, $microsecondsHash);"
+ case BinaryType => hashBytes(input)
case StringType =>
val baseObject = s"$input.getBaseObject()"
val baseOffset = s"$input.getBaseOffset()"
val numBytes = s"$input.numBytes()"
- inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)")
+ s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
- case ArrayType(et, _) =>
- val result = ctx.freshName("result")
+ case ArrayType(et, containsNull) =>
val index = ctx.freshName("index")
- val element = ctx.freshName("element")
- val elementHash = computeHash(element, et, result, ctx)
- val code =
- s"""
- int $result = $seed;
- for (int $index = 0; $index < $input.numElements(); $index++) {
- if (!$input.isNullAt($index)) {
- final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)};
- ${elementHash.code}
- $result = ${elementHash.value};
- }
- }
- """
- ExprCode(code, "false", result)
+ s"""
+ for (int $index = 0; $index < $input.numElements(); $index++) {
+ ${nullSafeElementHash(input, index, containsNull, et, result, ctx)}
+ }
+ """
- case MapType(kt, vt, _) =>
- val result = ctx.freshName("result")
+ case MapType(kt, vt, valueContainsNull) =>
val index = ctx.freshName("index")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
- val key = ctx.freshName("key")
- val value = ctx.freshName("value")
- val keyHash = computeHash(key, kt, result, ctx)
- val valueHash = computeHash(value, vt, result, ctx)
- val code =
- s"""
- int $result = $seed;
- final ArrayData $keys = $input.keyArray();
- final ArrayData $values = $input.valueArray();
- for (int $index = 0; $index < $input.numElements(); $index++) {
- final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)};
- ${keyHash.code}
- $result = ${keyHash.value};
- if (!$values.isNullAt($index)) {
- final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)};
- ${valueHash.code}
- $result = ${valueHash.value};
- }
- }
- """
- ExprCode(code, "false", result)
+ s"""
+ final ArrayData $keys = $input.keyArray();
+ final ArrayData $values = $input.valueArray();
+ for (int $index = 0; $index < $input.numElements(); $index++) {
+ ${nullSafeElementHash(keys, index, false, kt, result, ctx)}
+ ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)}
+ }
+ """
case StructType(fields) =>
- val result = ctx.freshName("result")
- val fieldsHash = fields.map(_.dataType).zipWithIndex.map {
- case (dt, index) =>
- val field = ctx.freshName("field")
- val fieldHash = computeHash(field, dt, result, ctx)
- s"""
- if (!$input.isNullAt($index)) {
- final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)};
- ${fieldHash.code}
- $result = ${fieldHash.value};
- }
- """
+ fields.zipWithIndex.map { case (field, index) =>
+ nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}.mkString("\n")
- val code =
- s"""
- int $result = $seed;
- $fieldsHash
- """
- ExprCode(code, "false", result)
- case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx)
+ case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx)
}
}
}
+
+/**
+ * Print the result of an expression to stderr (used for debugging codegen).
+ */
+case class PrintToStderr(child: Expression) extends UnaryExpression {
+
+ override def dataType: DataType = child.dataType
+
+ protected override def nullSafeEval(input: Any): Any = input
+
+ override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+ nullSafeCodeGen(ctx, ev, c =>
+ s"""
+ | System.err.println("Result of ${child.simpleString} is " + $c);
+ | ${ev.value} = $c;
+ """.stripMargin)
+ }
+}
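
The Murmur3Hash rewrite threads a single running `result` variable through the code generated for every child and nested element, with `nullSafeElementHash` skipping null elements, instead of composing nested ExprCode values. A hedged plain-Scala sketch of that accumulation shape over a toy value model; `mix` is a placeholder, not the real Murmur3 primitive:

```scala
object SeedThreadingHash {
  sealed trait Value
  final case class IntVal(i: Int) extends Value
  final case class ArrayVal(elements: Seq[Option[Value]]) extends Value
  final case class StructVal(fields: Seq[Option[Value]]) extends Value

  // Stand-in for Murmur3_x86_32.hashInt; only the accumulation shape matters here.
  private def mix(i: Int, seed: Int): Int = seed * 31 + i

  // One running result; null elements leave the accumulator unchanged,
  // mirroring nullSafeElementHash in the generated code.
  def computeHash(v: Value, result: Int): Int = v match {
    case IntVal(i) => mix(i, result)
    case ArrayVal(elements) =>
      elements.foldLeft(result) {
        case (acc, Some(e)) => computeHash(e, acc)
        case (acc, None)    => acc
      }
    case StructVal(fields) =>
      fields.foldLeft(result) {
        case (acc, Some(f)) => computeHash(f, acc)
        case (acc, None)    => acc
      }
  }
}
```
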
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 79fe0033b71ab..fef6825b2db5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -365,7 +365,7 @@ object MapObjects {
* to handle collection elements.
 * @param inputData An expression that when evaluated returns a collection object.
*/
-case class MapObjects(
+case class MapObjects private(
loopVar: LambdaVariable,
lambdaFunction: Expression,
inputData: Expression) extends Expression {
@@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
* `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
* non-null `s`, `s.i` can't be null.
*/
-case class AssertNotNull(
- child: Expression, parentType: String, fieldName: String, fieldType: String)
+case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
extends UnaryExpression {
override def dataType: DataType = child.dataType
@@ -651,6 +650,14 @@ case class AssertNotNull(
override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val childGen = child.gen(ctx)
+ val errMsg = "Null value appeared in non-nullable field:" +
+ walkedTypePath.mkString("\n", "\n", "\n") +
+ "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+ "please try to use scala.Option[_] or other nullable types " +
+ "(e.g. java.lang.Integer instead of int/scala.Int)."
+ val idx = ctx.references.length
+ ctx.references += errMsg
+
ev.isNull = "false"
ev.value = childGen.value
@@ -658,12 +665,7 @@ case class AssertNotNull(
${childGen.code}
if (${childGen.isNull}) {
- throw new RuntimeException(
- "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
- "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
- "please try to use scala.Option[_] or other nullable types " +
- "(e.g. java.lang.Integer instead of int/scala.Int)."
- );
+ throw new RuntimeException((String) references[$idx]);
}
"""
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 387d979484f2c..be6b2530ef39e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -233,18 +233,6 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri
override def copy(): GenericInternalRow = this
}
-/**
- * This is used for serialization of Python DataFrame
- */
-class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType)
- extends GenericInternalRow(values) {
-
- /** No-arg constructor for serialization. */
- protected def this() = this(null, null)
-
- def fieldIndex(name: String): Int = schema.fieldIndex(name)
-}
-
class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 6addc2080648b..902e18081bddf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
@@ -52,8 +52,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
// since the other rules might make two separate Unions operators adjacent.
Batch("Union", Once,
CombineUnions) ::
+ Batch("Replace Operators", FixedPoint(100),
+ ReplaceIntersectWithSemiJoin,
+ ReplaceDistinctWithAggregate) ::
Batch("Aggregate", FixedPoint(100),
- ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
@@ -66,7 +68,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
PushPredicateThroughAggregate,
ColumnPruning,
// Operator combine
- ProjectCollapsing,
+ CollapseRepartition,
+ CollapseProject,
CombineFilters,
CombineLimits,
CombineUnions,
@@ -116,26 +119,24 @@ object SamplePushDown extends Rule[LogicalPlan] {
*/
object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case m @ MapPartitions(_, input, _, child: ObjectOperator)
- if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType =>
+ case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
+ if !deserializer.isInstanceOf[Attribute] &&
+ deserializer.dataType == child.outputObject.dataType =>
val childWithoutSerialization = child.withObjectOutput
- m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization)
+ m.copy(
+ deserializer = childWithoutSerialization.output.head,
+ child = childWithoutSerialization)
}
}
/**
- * Pushes certain operations to both sides of a Union, Intersect or Except operator.
+ * Pushes certain operations to both sides of a Union or Except operator.
* Operations that are safe to pushdown are listed as follows.
* Union:
* Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
* safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT,
* we will not be able to pushdown Projections.
*
- * Intersect:
- * It is not safe to pushdown Projections through it because we need to get the
- * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * with deterministic condition.
- *
* Except:
* It is not safe to pushdown Projections through it because we need to get the
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
@@ -153,7 +154,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
/**
* Rewrites an expression so that it can be pushed to the right side of a
- * Union, Intersect or Except operator. This method relies on the fact that the output attributes
+ * Union or Except operator. This method relies on the fact that the output attributes
* of a union/intersect/except are always equal to the left child's output.
*/
private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
@@ -210,17 +211,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))
- // Push down filter through INTERSECT
- case Filter(condition, Intersect(left, right)) =>
- val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val rewrites = buildRewrites(left, right)
- Filter(nondeterministic,
- Intersect(
- Filter(deterministic, left),
- Filter(pushToRight(deterministic, rewrites), right)
- )
- )
-
// Push down filter through EXCEPT
case Filter(condition, Except(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
@@ -336,7 +326,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
* Combines two adjacent [[Project]] operators into one and perform alias substitution,
* merging the expressions into one single expression.
*/
-object ProjectCollapsing extends Rule[LogicalPlan] {
+object CollapseProject extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p @ Project(projectList1, Project(projectList2, child)) =>
@@ -404,6 +394,16 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
}
}
+/**
+ * Combines adjacent [[Repartition]] operators by keeping only the last one.
+ */
+object CollapseRepartition extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case r @ Repartition(numPartitions, shuffle, Repartition(_, _, child)) =>
+ Repartition(numPartitions, shuffle, child)
+ }
+}
+
/**
* Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
* For example, when the expression is just checking to see if a string starts with a given
@@ -871,6 +871,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
/**
* Splits join condition expressions into three categories based on the attributes required
* to evaluate them.
+ *
* @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
*/
private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
@@ -919,6 +920,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(rightFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case FullOuter => f // DO Nothing for Full Outer Join
+ case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
}
// push down the join filter into sub query scanning if applicable
@@ -953,6 +955,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
Join(newLeft, newRight, LeftOuter, newJoinCond)
case FullOuter => f
+ case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
}
}
}
@@ -1054,6 +1057,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
}
}
+/**
+ * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator.
+ * {{{
+ * SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2
+ * ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2
+ * }}}
+ *
+ * Note:
+ * 1. This rule is only applicable to INTERSECT DISTINCT. Do not use it for INTERSECT ALL.
+ * 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated
+ * join conditions will be incorrect.
+ */
+object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Intersect(left, right) =>
+ assert(left.output.size == right.output.size)
+ val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
+ Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
+ }
+}
+
/**
* Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
* but only makes the grouping key bigger.
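
ReplaceIntersectWithSemiJoin rests on the equivalence between INTERSECT DISTINCT and a deduplicated left semi join over null-safe equality of all columns. A collection-level sketch of that equivalence; the real rule builds EqualNullSafe conditions over the children's attributes, which plain Scala equality happens to model here:

```scala
object IntersectAsSemiJoin {
  // INTERSECT DISTINCT over plain rows.
  def intersectDistinct[A](left: Seq[A], right: Seq[A]): Set[A] =
    left.toSet intersect right.toSet

  // DISTINCT(left LEFT SEMI JOIN right): keep left rows with a match, then dedupe.
  def asSemiJoin[A](left: Seq[A], right: Seq[A]): Set[A] = {
    val rhs = right.toSet
    left.filter(rhs.contains).toSet
  }
}
```

For example, both `intersectDistinct(Seq(1, 2, 2, 3), Seq(2, 3, 4))` and `asSemiJoin(Seq(1, 2, 2, 3), Seq(2, 3, 4))` yield `Set(2, 3)`.
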
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala
index ec9812414e19f..28f7b10ed6a59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala
@@ -58,12 +58,12 @@ case class ASTNode(
override val origin: Origin = Origin(Some(line), Some(positionInLine))
/** Source text. */
- lazy val source: String = stream.toString(startIndex, stopIndex)
+ lazy val source: String = stream.toOriginalString(startIndex, stopIndex)
/** Get the source text that remains after this token. */
lazy val remainder: String = {
stream.fill()
- stream.toString(stopIndex + 1, stream.size() - 1).trim()
+ stream.toOriginalString(stopIndex + 1, stream.size() - 1).trim()
}
def text: String = token.getText
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b43b7ee71e7aa..05f5bdbfc0769 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{DataType, StructType}
@@ -26,6 +26,56 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
def output: Seq[Attribute]
+ /**
+ * Extracts the relevant constraints from a given set of constraints based on the attributes that
+ * appear in the [[outputSet]].
+ */
+ protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
+ constraints
+ .union(constructIsNotNullConstraints(constraints))
+ .filter(constraint =>
+ constraint.references.nonEmpty && constraint.references.subsetOf(outputSet))
+ }
+
+ /**
+ * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions.
+   * For example, if an expression is of the form (`a > 5`), this returns a constraint of the form
+   * `isNotNull(a)`.
+ */
+ private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
+ // Currently we only propagate constraints if the condition consists of equality
+ // and ranges. For all other cases, we return an empty set of constraints
+ constraints.map {
+ case EqualTo(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case GreaterThan(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case GreaterThanOrEqual(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case LessThan(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case LessThanOrEqual(l, r) =>
+ Set(IsNotNull(l), IsNotNull(r))
+ case _ =>
+ Set.empty[Expression]
+ }.foldLeft(Set.empty[Expression])(_ union _.toSet)
+ }
+
+ /**
+   * A sequence of expressions that describes the data properties of the output rows of this
+   * operator. For example, if the output of this operator is column `a`, an example `constraints`
+   * set can be `Set(a > 10, a < 20)`.
+ */
+ lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints)
+
+ /**
+ * This method can be overridden by any child class of QueryPlan to specify a set of constraints
+ * based on the given operator's constraint propagation logic. These constraints are then
+ * canonicalized and filtered automatically to contain only those attributes that appear in the
+   * canonicalized and filtered automatically to keep only those constraints that reference
+   * attributes appearing in the [[outputSet]].
+ protected def validConstraints: Set[Expression] = Set.empty
+
/**
* Returns the set of attributes that are output by this node.
*/
@@ -59,6 +109,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
* Runs [[transform]] with `rule` on all expressions present in this query operator.
* Users should not expect a specific directionality. If a specific directionality is needed,
* transformExpressionsDown or transformExpressionsUp should be used.
+ *
* @param rule the rule to be applied to every expression in this operator.
*/
def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
@@ -67,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
/**
* Runs [[transformDown]] with `rule` on all expressions present in this query operator.
+ *
* @param rule the rule to be applied to every expression in this operator.
*/
def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
@@ -99,6 +151,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
/**
* Runs [[transformUp]] with `rule` on all expressions present in this query operator.
+ *
* @param rule the rule to be applied to every expression in this operator.
* @return
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index a5f6764aef7ce..27a75326eba07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -60,3 +60,9 @@ case object FullOuter extends JoinType {
case object LeftSemi extends JoinType {
override def sql: String = "LEFT SEMI"
}
+
+case class NaturalJoin(tpe: JoinType) extends JoinType {
+ require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
+ "Unsupported natural join type " + tpe)
+ override def sql: String = "NATURAL " + tpe.sql
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 6d859551f8c52..d8944a424156e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -305,6 +305,8 @@ abstract class UnaryNode extends LogicalPlan {
def child: LogicalPlan
override def children: Seq[LogicalPlan] = child :: Nil
+
+ override protected def validConstraints: Set[Expression] = child.constraints
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e9c970cd08088..57575f9ee09ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -19,11 +19,23 @@ package org.apache.spark.sql.catalyst.plans.logical
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
+/**
+ * When planning take() or collect() operations, this special node is inserted at the top of
+ * the logical plan before invoking the query planner.
+ *
+ * Rules can pattern-match on this node in order to apply transformations that only take effect
+ * at the top of the logical query plan.
+ */
+case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+}
+
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -86,15 +98,26 @@ case class Generate(
}
}
-case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
+case class Filter(condition: Expression, child: LogicalPlan)
+ extends UnaryNode with PredicateHelper {
override def output: Seq[Attribute] = child.output
+
+ override protected def validConstraints: Set[Expression] = {
+ child.constraints.union(splitConjunctivePredicates(condition).toSet)
+ }
}
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- final override lazy val resolved: Boolean =
- childrenResolved &&
- left.output.length == right.output.length &&
- left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+
+ protected def leftConstraints: Set[Expression] = left.constraints
+
+ protected def rightConstraints: Set[Expression] = {
+ require(left.output.size == right.output.size)
+ val attributeRewrites = AttributeMap(right.output.zip(left.output))
+ right.constraints.map(_ transform {
+ case a: Attribute => attributeRewrites(a)
+ })
+ }
}
private[sql] object SetOperation {
@@ -103,15 +126,36 @@ private[sql] object SetOperation {
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
+ def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+
override def output: Seq[Attribute] =
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}
+
+ override protected def validConstraints: Set[Expression] = {
+ leftConstraints.union(rightConstraints)
+ }
+
+  // Intersect is only resolved if it doesn't introduce ambiguous expression ids,
+ // since the Optimizer will convert Intersect to Join.
+ override lazy val resolved: Boolean =
+ childrenResolved &&
+ left.output.length == right.output.length &&
+ left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } &&
+ duplicateResolved
}
case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
/** We don't use right.output because those rows get excluded from the set. */
override def output: Seq[Attribute] = left.output
+
+ override protected def validConstraints: Set[Expression] = leftConstraints
+
+ override lazy val resolved: Boolean =
+ childrenResolved &&
+ left.output.length == right.output.length &&
+ left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
}
/** Factory for constructing new `Union` nodes. */
@@ -146,13 +190,36 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
val sizeInBytes = children.map(_.statistics.sizeInBytes).sum
Statistics(sizeInBytes = sizeInBytes)
}
+
+ /**
+ * Maps the constraints containing a given (original) sequence of attributes to those with a
+ * given (reference) sequence of attributes. Given the nature of union, we expect that the
+   * mapping between the original and reference sequences is symmetric.
+ */
+ private def rewriteConstraints(
+ reference: Seq[Attribute],
+ original: Seq[Attribute],
+ constraints: Set[Expression]): Set[Expression] = {
+ require(reference.size == original.size)
+ val attributeRewrites = AttributeMap(original.zip(reference))
+ constraints.map(_ transform {
+ case a: Attribute => attributeRewrites(a)
+ })
+ }
+
+ override protected def validConstraints: Set[Expression] = {
+ children
+ .map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
+ .reduce(_ intersect _)
+ }
}
case class Join(
- left: LogicalPlan,
- right: LogicalPlan,
- joinType: JoinType,
- condition: Option[Expression]) extends BinaryNode {
+ left: LogicalPlan,
+ right: LogicalPlan,
+ joinType: JoinType,
+ condition: Option[Expression])
+ extends BinaryNode with PredicateHelper {
override def output: Seq[Attribute] = {
joinType match {
@@ -169,15 +236,45 @@ case class Join(
}
}
- def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+ override protected def validConstraints: Set[Expression] = {
+ joinType match {
+ case Inner if condition.isDefined =>
+ left.constraints
+ .union(right.constraints)
+ .union(splitConjunctivePredicates(condition.get).toSet)
+ case LeftSemi if condition.isDefined =>
+ left.constraints
+ .union(splitConjunctivePredicates(condition.get).toSet)
+ case Inner =>
+ left.constraints.union(right.constraints)
+ case LeftSemi =>
+ left.constraints
+ case LeftOuter =>
+ left.constraints
+ case RightOuter =>
+ right.constraints
+ case FullOuter =>
+ Set.empty[Expression]
+ }
+ }
+
+ def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
// Joins are only resolved if they don't introduce ambiguous expression ids.
- override lazy val resolved: Boolean = {
+ // NaturalJoin should be ready for resolution only if everything else is resolved here
+ lazy val resolvedExceptNatural: Boolean = {
childrenResolved &&
expressions.forall(_.resolved) &&
- selfJoinResolved &&
+ duplicateResolved &&
condition.forall(_.dataType == BooleanType)
}
+
+  // If this is not a natural join, use `resolvedExceptNatural`. If it is a natural join, we still
+  // need to eliminate the natural join before we mark it resolved.
+ override lazy val resolved: Boolean = joinType match {
+ case NaturalJoin(_) => false
+ case _ => resolvedExceptNatural
+ }
}
/**
@@ -249,7 +346,7 @@ case class Range(
end: Long,
step: Long,
numSlices: Int,
- output: Seq[Attribute]) extends LeafNode {
+ output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation {
require(step != 0, "step cannot be 0")
val numElements: BigInt = {
val safeStart = BigInt(start)
@@ -262,6 +359,9 @@ case class Range(
}
}
+ override def newInstance(): Range =
+ Range(start, end, step, numSlices, output.map(_.newInstance()))
+
override def statistics: Statistics = {
val sizeInBytes = LongType.defaultSize * numElements
Statistics( sizeInBytes = sizeInBytes )
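
Union keeps only the constraints common to all children, after rewriting each child's constraints onto the first child's attributes position by position; Intersect unions the constraints of both sides, while Except keeps only the left side's. A small sketch of the Union case over string-named columns, with a hypothetical single-attribute Constraint type:

```scala
object UnionConstraintRewrite {
  // A constraint tracked with the one column name it refers to (illustrative).
  final case class Constraint(attr: String, predicate: String)

  // Mirrors rewriteConstraints: map each child's column names onto the
  // reference (head child) names, position by position.
  def rewrite(reference: Seq[String], original: Seq[String], cs: Set[Constraint]): Set[Constraint] = {
    require(reference.size == original.size)
    val mapping = original.zip(reference).toMap
    cs.map(c => c.copy(attr = mapping.getOrElse(c.attr, c.attr)))
  }

  // Mirrors Union.validConstraints: only constraints that hold for every child survive.
  def unionConstraints(children: Seq[(Seq[String], Set[Constraint])]): Set[Constraint] = {
    val headOutput = children.head._1
    children
      .map { case (output, cs) => rewrite(headOutput, output, cs) }
      .reduce(_ intersect _)
  }
}
```
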
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 760348052739c..3f97662957b8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.ObjectType
+import org.apache.spark.sql.types.{ObjectType, StructType}
/**
* A trait for logical operators that apply user defined functions to domain objects.
@@ -30,6 +30,15 @@ trait ObjectOperator extends LogicalPlan {
/** The serializer that is used to produce the output of this operator. */
def serializer: Seq[NamedExpression]
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ /**
+ * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
+ * It must also provide the attributes that are available during the resolution of each
+ * deserializer.
+ */
+ def deserializers: Seq[(Expression, Seq[Attribute])]
+
/**
* The object type that is produced by the user defined function. Note that the return type here
* is the same whether or not the operator is output serialized data.
@@ -44,13 +53,13 @@ trait ObjectOperator extends LogicalPlan {
def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) {
this
} else {
- withNewSerializer(outputObject)
+ withNewSerializer(outputObject :: Nil)
}
/** Returns a copy of this operator with a different serializer. */
- def withNewSerializer(newSerializer: NamedExpression): LogicalPlan = makeCopy {
+ def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy {
productIterator.map {
- case c if c == serializer => newSerializer :: Nil
+ case c if c == serializer => newSerializer
case other: AnyRef => other
}.toArray
}
@@ -70,15 +79,16 @@ object MapPartitions {
/**
* A relation produced by applying `func` to each partition of the `child`.
- * @param input used to extract the input to `func` from an input row.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
* @param serializer use to serialize the output of `func`.
*/
case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
- input: Expression,
+ deserializer: Expression,
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectOperator {
- override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+ override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}
/** Factory for constructing new `AppendColumn` nodes. */
@@ -97,16 +107,21 @@ object AppendColumns {
/**
* A relation produced by applying `func` to each partition of the `child`, concatenating the
* resulting columns at the end of the input row.
- * @param input used to extract the input to `func` from an input row.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
* @param serializer use to serialize the output of `func`.
*/
case class AppendColumns(
func: Any => Any,
- input: Expression,
+ deserializer: Expression,
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectOperator {
+
override def output: Seq[Attribute] = child.output ++ newColumns
+
def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}
/** Factory for constructing new `MapGroups` nodes. */
@@ -114,6 +129,7 @@ object MapGroups {
def apply[K : Encoder, T : Encoder, U : Encoder](
func: (K, Iterator[T]) => TraversableOnce[U],
groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
child: LogicalPlan): MapGroups = {
new MapGroups(
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
@@ -121,6 +137,7 @@ object MapGroups {
encoderFor[T].fromRowExpression,
encoderFor[U].namedExpressions,
groupingAttributes,
+ dataAttributes,
child)
}
}
@@ -129,19 +146,22 @@ object MapGroups {
* Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
* Func is invoked with an object representation of the grouping key an iterator containing the
* object representation of all the rows with that key.
- * @param keyObject used to extract the key object for each group.
- * @param input used to extract the items in the iterator from an input row.
+ *
+ * @param keyDeserializer used to extract the key object for each group.
+ * @param valueDeserializer used to extract the items in the iterator from an input row.
* @param serializer use to serialize the output of `func`.
*/
case class MapGroups(
func: (Any, Iterator[Any]) => TraversableOnce[Any],
- keyObject: Expression,
- input: Expression,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
child: LogicalPlan) extends UnaryNode with ObjectOperator {
- def output: Seq[Attribute] = serializer.map(_.toAttribute)
+ override def deserializers: Seq[(Expression, Seq[Attribute])] =
+ Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
}
/** Factory for constructing new `CoGroup` nodes. */
@@ -150,8 +170,12 @@ object CoGroup {
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
+ leftData: Seq[Attribute],
+ rightData: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup = {
+ require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
+
CoGroup(
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
encoderFor[Key].fromRowExpression,
@@ -160,6 +184,8 @@ object CoGroup {
encoderFor[Result].namedExpressions,
leftGroup,
rightGroup,
+ leftData,
+ rightData,
left,
right)
}
@@ -171,15 +197,21 @@ object CoGroup {
*/
case class CoGroup(
func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
- keyObject: Expression,
- leftObject: Expression,
- rightObject: Expression,
+ keyDeserializer: Expression,
+ leftDeserializer: Expression,
+ rightDeserializer: Expression,
serializer: Seq[NamedExpression],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
+ leftAttr: Seq[Attribute],
+ rightAttr: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode with ObjectOperator {
+
override def producedAttributes: AttributeSet = outputSet
- override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+ override def deserializers: Seq[(Expression, Seq[Attribute])] =
+ // The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to
+ // resolve the `keyDeserializer` based on either of them; here we pick the left one.
+ Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 57e1a3c9eb226..2df0683f9fa16 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -512,7 +512,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
protected def jsonFields: List[JField] = {
- val fieldNames = getConstructorParameters(getClass).map(_._1)
+ val fieldNames = getConstructorParameterNames(getClass)
val fieldValues = productIterator.toSeq ++ otherCopyArgs
assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", "))
@@ -560,7 +560,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName
// returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper
case p: Product => try {
- val fieldNames = getConstructorParameters(p.getClass).map(_._1)
+ val fieldNames = getConstructorParameterNames(p.getClass)
val fieldValues = p.productIterator.toSeq
assert(fieldNames.length == fieldValues.length)
("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index f18c052b68e37..a159bc6a61415 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -55,6 +55,7 @@ object DateTimeUtils {
// this is year -17999, calculation: 50 * daysIn400Year
final val YearZero = -17999
final val toYearZero = to2001 + 7304850
+ final val TimeZoneGMT = TimeZone.getTimeZone("GMT")
@transient lazy val defaultTimeZone = TimeZone.getDefault
@@ -407,7 +408,7 @@ object DateTimeUtils {
segments(2) < 1 || segments(2) > 31) {
return None
}
- val c = Calendar.getInstance(TimeZone.getTimeZone("GMT"))
+ val c = Calendar.getInstance(TimeZoneGMT)
c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0)
c.set(Calendar.MILLISECOND, 0)
Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index cf5322125bd72..5dd661ee6b339 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -148,6 +148,28 @@ object DecimalType extends AbstractDataType {
}
}
+ /**
+ * Returns whether dt is a DecimalType that fits inside a long
+ */
+ def is64BitDecimalType(dt: DataType): Boolean = {
+ dt match {
+ case t: DecimalType =>
+ t.precision <= Decimal.MAX_LONG_DIGITS
+ case _ => false
+ }
+ }
+
+ /**
+ * Returns whether dt is a DecimalType that doesn't fit inside a long
+ */
+ def isByteArrayDecimalType(dt: DataType): Boolean = {
+ dt match {
+ case t: DecimalType =>
+ t.precision > Decimal.MAX_LONG_DIGITS
+ case _ => false
+ }
+ }
+
def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
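// Illustrative sketch (not part of the diff) of the two new predicates, assuming
// Decimal.MAX_LONG_DIGITS is 18 as in Spark: precision up to 18 fits in a single long,
// anything larger falls back to a byte-array backed representation.
import org.apache.spark.sql.types._

assert(DecimalType.is64BitDecimalType(DecimalType(18, 2)))      // fits in a long
assert(DecimalType.isByteArrayDecimalType(DecimalType(38, 18))) // needs arbitrary precision storage
assert(!DecimalType.is64BitDecimalType(IntegerType))            // non-decimal types return false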
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
index 9e0f9943bc638..66f123682e117 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
@@ -273,4 +273,9 @@ class MetadataBuilder {
map.put(key, value)
this
}
+
+ def remove(key: String): this.type = {
+ map.remove(key)
+ this
+ }
}
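// Illustrative sketch (not part of the diff): the new remove(key) call composes with the
// existing builder methods, e.g. to strip a marker key while keeping everything else.
// The key names used here are hypothetical.
import org.apache.spark.sql.types.MetadataBuilder

val original = new MetadataBuilder()
  .putBoolean("_OPTIONAL_", true)
  .putString("comment", "kept")
  .build()
val cleaned = new MetadataBuilder()
  .withMetadata(original)
  .remove("_OPTIONAL_")
  .build()
assert(!cleaned.contains("_OPTIONAL_") && cleaned.getString("comment") == "kept")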
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 3bd733fa2d26c..e797d83cb05be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParse
* Example:
* {{{
* import org.apache.spark.sql._
+ * import org.apache.spark.sql.types._
*
* val struct =
* StructType(
@@ -334,6 +335,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
object StructType extends AbstractDataType {
+ private[sql] val metadataKeyForOptionalField = "_OPTIONAL_"
+
override private[sql] def defaultConcreteType: DataType = new StructType
override private[sql] def acceptsType(other: DataType): Boolean = {
@@ -359,6 +362,18 @@ object StructType extends AbstractDataType {
protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
+ def removeMetadata(key: String, dt: DataType): DataType =
+ dt match {
+ case StructType(fields) =>
+ val newFields = fields.map { f =>
+ val mb = new MetadataBuilder()
+ f.copy(dataType = removeMetadata(key, f.dataType),
+ metadata = mb.withMetadata(f.metadata).remove(key).build())
+ }
+ StructType(newFields)
+ case _ => dt
+ }
+
private[sql] def merge(left: DataType, right: DataType): DataType =
(left, right) match {
case (ArrayType(leftElementType, leftContainsNull),
@@ -376,24 +391,32 @@ object StructType extends AbstractDataType {
case (StructType(leftFields), StructType(rightFields)) =>
val newFields = ArrayBuffer.empty[StructField]
+ // This metadata will record the fields that only exist in one of the two StructTypes
+ val optionalMeta = new MetadataBuilder()
val rightMapped = fieldsMap(rightFields)
leftFields.foreach {
case leftField @ StructField(leftName, leftType, leftNullable, _) =>
rightMapped.get(leftName)
.map { case rightField @ StructField(_, rightType, rightNullable, _) =>
- leftField.copy(
- dataType = merge(leftType, rightType),
- nullable = leftNullable || rightNullable)
- }
- .orElse(Some(leftField))
+ leftField.copy(
+ dataType = merge(leftType, rightType),
+ nullable = leftNullable || rightNullable)
+ }
+ .orElse {
+ optionalMeta.putBoolean(metadataKeyForOptionalField, true)
+ Some(leftField.copy(metadata = optionalMeta.build()))
+ }
.foreach(newFields += _)
}
val leftMapped = fieldsMap(leftFields)
rightFields
.filterNot(f => leftMapped.get(f.name).nonEmpty)
- .foreach(newFields += _)
+ .foreach { f =>
+ optionalMeta.putBoolean(metadataKeyForOptionalField, true)
+ newFields += f.copy(metadata = optionalMeta.build())
+ }
StructType(newFields)
@@ -402,13 +425,13 @@ object StructType extends AbstractDataType {
if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) {
DecimalType(leftPrecision, leftScale)
} else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) {
- throw new SparkException("Failed to merge Decimal Tpes with incompatible " +
+ throw new SparkException("Failed to merge decimal types with incompatible " +
s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale")
} else if (leftPrecision != rightPrecision) {
- throw new SparkException("Failed to merge Decimal Tpes with incompatible " +
+ throw new SparkException("Failed to merge decimal types with incompatible " +
s"precision $leftPrecision and $rightPrecision")
} else {
- throw new SparkException("Failed to merge Decimal Tpes with incompatible " +
+ throw new SparkException("Failed to merge decimal types with incompatible " +
s"scala $leftScale and $rightScale")
}
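// Illustrative sketch (not part of the diff): merge() now tags fields that exist on only one
// side with the "_OPTIONAL_" marker, and removeMetadata() strips that marker again, e.g.
// before comparing a merged file schema against a user-provided schema. The schema below is
// built by hand as a stand-in for the output of a merge.
import org.apache.spark.sql.types._

val optional = new MetadataBuilder().putBoolean("_OPTIONAL_", true).build()
val merged = StructType(Seq(
  StructField("a", IntegerType),
  StructField("b", StringType, nullable = true, metadata = optional)))  // only on one side

val stripped = StructType.removeMetadata("_OPTIONAL_", merged).asInstanceOf[StructType]
assert(!stripped("b").metadata.contains("_OPTIONAL_"))
assert(stripped == StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))))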
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 55efea80d1a4d..7c173cbceefed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -47,9 +47,9 @@ object RandomDataGenerator {
*/
private val PROBABILITY_OF_NULL: Float = 0.1f
- private val MAX_STR_LEN: Int = 1024
- private val MAX_ARR_SIZE: Int = 128
- private val MAX_MAP_SIZE: Int = 128
+ final val MAX_STR_LEN: Int = 1024
+ final val MAX_ARR_SIZE: Int = 128
+ final val MAX_MAP_SIZE: Int = 128
/**
* Helper function for constructing a biased random number generator which returns "interesting"
@@ -208,7 +208,17 @@ object RandomDataGenerator {
forType(valueType, nullable = valueContainsNull, rand)
) yield {
() => {
- Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap
+ val length = rand.nextInt(MAX_MAP_SIZE)
+ val keys = scala.collection.mutable.HashSet(Seq.fill(length)(keyGenerator()): _*)
+ // In case there are not enough distinct keys, cap the number of extra iterations to avoid
+ // an infinite loop.
+ var count = 0
+ while (keys.size < length && count < MAX_MAP_SIZE) {
+ keys += keyGenerator()
+ count += 1
+ }
+ val values = Seq.fill(keys.size)(valueGenerator())
+ keys.zip(values).toMap
}
}
}
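// Illustrative standalone sketch (not part of the diff) of the bounded-retry idea above: keep
// drawing keys until either `length` distinct ones have been collected or MAX_MAP_SIZE extra
// attempts have been made, so a key generator with low cardinality cannot loop forever.
// The Int key generator below is a stand-in for RandomDataGenerator's.
import scala.util.Random

val rand = new Random(42)
val MAX_MAP_SIZE = 128
def keyGenerator(): Int = rand.nextInt(4)  // deliberately only 4 possible keys

val length = rand.nextInt(MAX_MAP_SIZE)
val keys = scala.collection.mutable.HashSet(Seq.fill(length)(keyGenerator()): _*)
var count = 0
while (keys.size < length && count < MAX_MAP_SIZE) {
  keys += keyGenerator()
  count += 1
}
// With only 4 distinct keys available, the loop exits via the count bound instead of spinning.
assert(keys.size <= 4 && count <= MAX_MAP_SIZE)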
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
index b8ccdf7516d82..9fba7924e9542 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
@@ -95,4 +95,15 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
}
}
+ test("check size of generated map") {
+ val mapType = MapType(IntegerType, IntegerType)
+ for (seed <- 1 to 1000) {
+ val generator = RandomDataGenerator.forType(
+ mapType, nullable = false, rand = new Random(seed)).get
+ val maps = Seq.fill(100)(generator().asInstanceOf[Map[Int, Int]])
+ val expectedTotalElements = 100 / 2 * RandomDataGenerator.MAX_MAP_SIZE
+ val deviation = math.abs(maps.map(_.size).sum - expectedTotalElements)
+ assert(deviation.toDouble / expectedTotalElements < 2e-1)
+ }
+ }
}
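// Quick arithmetic behind the expectation in the test above (illustrative): each map's length
// is drawn uniformly from 0 until MAX_MAP_SIZE, so the mean size is roughly MAX_MAP_SIZE / 2,
// and 100 maps should hold about 100 / 2 * 128 = 6400 entries in total; the test then allows
// a 20% relative deviation from that figure.
val expectedTotalElements = 100 / 2 * 128
assert(expectedTotalElements == 6400)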
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ab680282208c8..ebf885a8fe484 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -76,6 +76,89 @@ class AnalysisSuite extends AnalysisTest {
caseSensitive = false)
}
+ test("resolve sort references - filter/limit") {
+ val a = testRelation2.output(0)
+ val b = testRelation2.output(1)
+ val c = testRelation2.output(2)
+
+ // Case 1: one missing attribute is in the leaf node and another is in the unary node
+ val plan1 = testRelation2
+ .where('a > "str").select('a, 'b)
+ .where('b > "str").select('a)
+ .sortBy('b.asc, 'c.desc)
+ val expected1 = testRelation2
+ .where(a > "str").select(a, b, c)
+ .where(b > "str").select(a, b, c)
+ .sortBy(b.asc, c.desc)
+ .select(a, b).select(a)
+ checkAnalysis(plan1, expected1)
+
+ // Case 2: all the missing attributes are in the leaf node
+ val plan2 = testRelation2
+ .where('a > "str").select('a)
+ .where('a > "str").select('a)
+ .sortBy('b.asc, 'c.desc)
+ val expected2 = testRelation2
+ .where(a > "str").select(a, b, c)
+ .where(a > "str").select(a, b, c)
+ .sortBy(b.asc, c.desc)
+ .select(a)
+ checkAnalysis(plan2, expected2)
+ }
+
+ test("resolve sort references - join") {
+ val a = testRelation2.output(0)
+ val b = testRelation2.output(1)
+ val c = testRelation2.output(2)
+ val h = testRelation3.output(3)
+
+ // Case: join itself can resolve all the missing attributes
+ val plan = testRelation2.join(testRelation3)
+ .where('a > "str").select('a, 'b)
+ .sortBy('c.desc, 'h.asc)
+ val expected = testRelation2.join(testRelation3)
+ .where(a > "str").select(a, b, c, h)
+ .sortBy(c.desc, h.asc)
+ .select(a, b)
+ checkAnalysis(plan, expected)
+ }
+
+ test("resolve sort references - aggregate") {
+ val a = testRelation2.output(0)
+ val b = testRelation2.output(1)
+ val c = testRelation2.output(2)
+ val alias_a3 = count(a).as("a3")
+ val alias_b = b.as("aggOrder")
+
+ // Case 1: when the child of Sort is not Aggregate,
+ // the sort reference is handled by the rule ResolveSortReferences
+ val plan1 = testRelation2
+ .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
+ .select('a, 'c, 'a3)
+ .orderBy('b.asc)
+
+ val expected1 = testRelation2
+ .groupBy(a, c, b)(a, c, alias_a3, b)
+ .select(a, c, alias_a3.toAttribute, b)
+ .orderBy(b.asc)
+ .select(a, c, alias_a3.toAttribute)
+
+ checkAnalysis(plan1, expected1)
+
+ // Case 2: when the child of Sort is Aggregate,
+ // the sort reference is handled by the rule ResolveAggregateFunctions
+ val plan2 = testRelation2
+ .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
+ .orderBy('b.asc)
+
+ val expected2 = testRelation2
+ .groupBy(a, c, b)(a, c, alias_a3, alias_b)
+ .orderBy(alias_b.toAttribute.asc)
+ .select(a, c, alias_a3.toAttribute)
+
+ checkAnalysis(plan2, expected2)
+ }
+
test("resolve relations") {
assertAnalysisError(
UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe"))
@@ -154,6 +237,11 @@ class AnalysisSuite extends AnalysisTest {
checkAnalysis(plan, expected)
}
+ test("self intersect should resolve duplicate expression IDs") {
+ val plan = testRelation.intersect(testRelation)
+ assertAnalysisSuccess(plan)
+ }
+
test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
LocalRelation()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
new file mode 100644
index 0000000000000..fcf4ac1967a53
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+class ResolveNaturalJoinSuite extends AnalysisTest {
+ lazy val a = 'a.string
+ lazy val b = 'b.string
+ lazy val c = 'c.string
+ lazy val aNotNull = a.notNull
+ lazy val bNotNull = b.notNull
+ lazy val cNotNull = c.notNull
+ lazy val r1 = LocalRelation(b, a)
+ lazy val r2 = LocalRelation(c, a)
+ lazy val r3 = LocalRelation(aNotNull, bNotNull)
+ lazy val r4 = LocalRelation(cNotNull, bNotNull)
+
+ test("natural inner join") {
+ val plan = r1.join(r2, NaturalJoin(Inner), None)
+ val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural left join") {
+ val plan = r1.join(r2, NaturalJoin(LeftOuter), None)
+ val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural right join") {
+ val plan = r1.join(r2, NaturalJoin(RightOuter), None)
+ val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural full outer join") {
+ val plan = r1.join(r2, NaturalJoin(FullOuter), None)
+ val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select(
+ Alias(Coalesce(Seq(a, a)), "a")(), b, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural inner join with no nullability") {
+ val plan = r3.join(r4, NaturalJoin(Inner), None)
+ val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select(
+ bNotNull, aNotNull, cNotNull)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural left join with no nullability") {
+ val plan = r3.join(r4, NaturalJoin(LeftOuter), None)
+ val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select(
+ bNotNull, aNotNull, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural right join with no nullability") {
+ val plan = r3.join(r4, NaturalJoin(RightOuter), None)
+ val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select(
+ bNotNull, a, cNotNull)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural full outer join with no nullability") {
+ val plan = r3.join(r4, NaturalJoin(FullOuter), None)
+ val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
+ Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
+ checkAnalysis(plan, expected)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
index bc07b609a3413..3741a6ba95a86 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
@@ -31,6 +31,12 @@ object TestRelations {
AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())
+ val testRelation3 = LocalRelation(
+ AttributeReference("e", ShortType)(),
+ AttributeReference("f", StringType)(),
+ AttributeReference("g", DoubleType)(),
+ AttributeReference("h", DecimalType(10, 2))())
+
val nestedRelation = LocalRelation(
AttributeReference("top", StructType(
StructField("duplicateField", StringType) ::
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
new file mode 100644
index 0000000000000..45c5ceecb0eef
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
@@ -0,0 +1,453 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+
+
+/**
+ * A reasonably complete test suite (i.e. a set of behaviors) for a [[Catalog]].
+ *
+ * Implementations of the [[Catalog]] interface can create test suites by extending this.
+ */
+abstract class CatalogTestCases extends SparkFunSuite {
+ private val storageFormat = StorageFormat("usa", "$", "zzz", "serde", Map())
+ private val part1 = TablePartition(Map("a" -> "1"), storageFormat)
+ private val part2 = TablePartition(Map("b" -> "2"), storageFormat)
+ private val part3 = TablePartition(Map("c" -> "3"), storageFormat)
+ private val funcClass = "org.apache.spark.myFunc"
+
+ protected def newEmptyCatalog(): Catalog
+
+ /**
+ * Creates a basic catalog, with the following structure:
+ *
+ * db1
+ * db2
+ * - tbl1
+ * - tbl2
+ * - part1
+ * - part2
+ * - func1
+ */
+ private def newBasicCatalog(): Catalog = {
+ val catalog = newEmptyCatalog()
+ catalog.createDatabase(newDb("db1"), ignoreIfExists = false)
+ catalog.createDatabase(newDb("db2"), ignoreIfExists = false)
+ catalog.createTable("db2", newTable("tbl1"), ignoreIfExists = false)
+ catalog.createTable("db2", newTable("tbl2"), ignoreIfExists = false)
+ catalog.createPartitions("db2", "tbl2", Seq(part1, part2), ignoreIfExists = false)
+ catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false)
+ catalog
+ }
+
+ private def newFunc(): Function = Function("funcname", funcClass)
+
+ private def newDb(name: String = "default"): Database =
+ Database(name, name + " description", "uri", Map.empty)
+
+ private def newTable(name: String): Table =
+ Table(name, "", Seq.empty, Seq.empty, Seq.empty, null, 0, Map.empty, "EXTERNAL_TABLE", 0, 0,
+ None, None)
+
+ private def newFunc(name: String): Function = Function(name, funcClass)
+
+ // --------------------------------------------------------------------------
+ // Databases
+ // --------------------------------------------------------------------------
+
+ test("basic create, drop and list databases") {
+ val catalog = newEmptyCatalog()
+ catalog.createDatabase(newDb(), ignoreIfExists = false)
+ assert(catalog.listDatabases().toSet == Set("default"))
+
+ catalog.createDatabase(newDb("default2"), ignoreIfExists = false)
+ assert(catalog.listDatabases().toSet == Set("default", "default2"))
+ }
+
+ test("get database when a database exists") {
+ val db1 = newBasicCatalog().getDatabase("db1")
+ assert(db1.name == "db1")
+ assert(db1.description.contains("db1"))
+ }
+
+ test("get database should throw exception when the database does not exist") {
+ intercept[AnalysisException] { newBasicCatalog().getDatabase("db_that_does_not_exist") }
+ }
+
+ test("list databases without pattern") {
+ val catalog = newBasicCatalog()
+ assert(catalog.listDatabases().toSet == Set("db1", "db2"))
+ }
+
+ test("list databases with pattern") {
+ val catalog = newBasicCatalog()
+ assert(catalog.listDatabases("db").toSet == Set.empty)
+ assert(catalog.listDatabases("db*").toSet == Set("db1", "db2"))
+ assert(catalog.listDatabases("*1").toSet == Set("db1"))
+ assert(catalog.listDatabases("db2").toSet == Set("db2"))
+ }
+
+ test("drop database") {
+ val catalog = newBasicCatalog()
+ catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false)
+ assert(catalog.listDatabases().toSet == Set("db2"))
+ }
+
+ test("drop database when the database is not empty") {
+ // Throw exception if there are functions left
+ val catalog1 = newBasicCatalog()
+ catalog1.dropTable("db2", "tbl1", ignoreIfNotExists = false)
+ catalog1.dropTable("db2", "tbl2", ignoreIfNotExists = false)
+ intercept[AnalysisException] {
+ catalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false)
+ }
+
+ // Throw exception if there are tables left
+ val catalog2 = newBasicCatalog()
+ catalog2.dropFunction("db2", "func1")
+ intercept[AnalysisException] {
+ catalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false)
+ }
+
+ // When cascade is true, it should drop them
+ val catalog3 = newBasicCatalog()
+ catalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true)
+ assert(catalog3.listDatabases().toSet == Set("db1"))
+ }
+
+ test("drop database when the database does not exist") {
+ val catalog = newBasicCatalog()
+
+ intercept[AnalysisException] {
+ catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false)
+ }
+
+ catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false)
+ }
+
+ test("alter database") {
+ val catalog = newBasicCatalog()
+ catalog.alterDatabase("db1", Database("db1", "new description", "lll", Map.empty))
+ assert(catalog.getDatabase("db1").description == "new description")
+ }
+
+ test("alter database should throw exception when the database does not exist") {
+ intercept[AnalysisException] {
+ newBasicCatalog().alterDatabase("no_db", Database("no_db", "ddd", "lll", Map.empty))
+ }
+ }
+
+ // --------------------------------------------------------------------------
+ // Tables
+ // --------------------------------------------------------------------------
+
+ test("drop table") {
+ val catalog = newBasicCatalog()
+ assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
+ catalog.dropTable("db2", "tbl1", ignoreIfNotExists = false)
+ assert(catalog.listTables("db2").toSet == Set("tbl2"))
+ }
+
+ test("drop table when database / table does not exist") {
+ val catalog = newBasicCatalog()
+
+ // Should always throw exception when the database does not exist
+ intercept[AnalysisException] {
+ catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = false)
+ }
+
+ intercept[AnalysisException] {
+ catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = true)
+ }
+
+ // Should throw exception when the table does not exist, if ignoreIfNotExists is false
+ intercept[AnalysisException] {
+ catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = false)
+ }
+
+ catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = true)
+ }
+
+ test("rename table") {
+ val catalog = newBasicCatalog()
+
+ assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
+ catalog.renameTable("db2", "tbl1", "tblone")
+ assert(catalog.listTables("db2").toSet == Set("tblone", "tbl2"))
+ }
+
+ test("rename table when database / table does not exist") {
+ val catalog = newBasicCatalog()
+
+ intercept[AnalysisException] { // Throw exception when the database does not exist
+ catalog.renameTable("unknown_db", "unknown_table", "unknown_table")
+ }
+
+ intercept[AnalysisException] { // Throw exception when the table does not exist
+ catalog.renameTable("db2", "unknown_table", "unknown_table")
+ }
+ }
+
+ test("alter table") {
+ val catalog = newBasicCatalog()
+ catalog.alterTable("db2", "tbl1", newTable("tbl1").copy(createTime = 10))
+ assert(catalog.getTable("db2", "tbl1").createTime == 10)
+ }
+
+ test("alter table when database / table does not exist") {
+ val catalog = newBasicCatalog()
+
+ intercept[AnalysisException] { // Throw exception when the database does not exist
+ catalog.alterTable("unknown_db", "unknown_table", newTable("unknown_table"))
+ }
+
+ intercept[AnalysisException] { // Throw exception when the table does not exist
+ catalog.alterTable("db2", "unknown_table", newTable("unknown_table"))
+ }
+ }
+
+ test("get table") {
+ assert(newBasicCatalog().getTable("db2", "tbl1").name == "tbl1")
+ }
+
+ test("get table when database / table does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.getTable("unknown_db", "unknown_table")
+ }
+
+ intercept[AnalysisException] {
+ catalog.getTable("db2", "unknown_table")
+ }
+ }
+
+ test("list tables without pattern") {
+ val catalog = newBasicCatalog()
+ assert(catalog.listTables("db1").toSet == Set.empty)
+ assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
+ }
+
+ test("list tables with pattern") {
+ val catalog = newBasicCatalog()
+
+ // Test when database does not exist
+ intercept[AnalysisException] { catalog.listTables("unknown_db") }
+
+ assert(catalog.listTables("db1", "*").toSet == Set.empty)
+ assert(catalog.listTables("db2", "*").toSet == Set("tbl1", "tbl2"))
+ assert(catalog.listTables("db2", "tbl*").toSet == Set("tbl1", "tbl2"))
+ assert(catalog.listTables("db2", "*1").toSet == Set("tbl1"))
+ }
+
+ // --------------------------------------------------------------------------
+ // Partitions
+ // --------------------------------------------------------------------------
+
+ test("basic create and list partitions") {
+ val catalog = newEmptyCatalog()
+ catalog.createDatabase(newDb("mydb"), ignoreIfExists = false)
+ catalog.createTable("mydb", newTable("mytbl"), ignoreIfExists = false)
+ catalog.createPartitions("mydb", "mytbl", Seq(part1, part2), ignoreIfExists = false)
+ assert(catalog.listPartitions("mydb", "mytbl").toSet == Set(part1, part2))
+ }
+
+ test("create partitions when database / table does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.createPartitions("does_not_exist", "tbl1", Seq(), ignoreIfExists = false)
+ }
+ intercept[AnalysisException] {
+ catalog.createPartitions("db2", "does_not_exist", Seq(), ignoreIfExists = false)
+ }
+ }
+
+ test("create partitions that already exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = false)
+ }
+ catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = true)
+ }
+
+ test("drop partitions") {
+ val catalog = newBasicCatalog()
+ assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part1, part2))
+ catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false)
+ assert(catalog.listPartitions("db2", "tbl2").toSet == Set(part2))
+ val catalog2 = newBasicCatalog()
+ assert(catalog2.listPartitions("db2", "tbl2").toSet == Set(part1, part2))
+ catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false)
+ assert(catalog2.listPartitions("db2", "tbl2").isEmpty)
+ }
+
+ test("drop partitions when database / table does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false)
+ }
+ intercept[AnalysisException] {
+ catalog.dropPartitions("db2", "does_not_exist", Seq(), ignoreIfNotExists = false)
+ }
+ }
+
+ test("drop partitions that do not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false)
+ }
+ catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true)
+ }
+
+ test("get partition") {
+ val catalog = newBasicCatalog()
+ assert(catalog.getPartition("db2", "tbl2", part1.spec) == part1)
+ assert(catalog.getPartition("db2", "tbl2", part2.spec) == part2)
+ intercept[AnalysisException] {
+ catalog.getPartition("db2", "tbl1", part3.spec)
+ }
+ }
+
+ test("get partition when database / table does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.getPartition("does_not_exist", "tbl1", part1.spec)
+ }
+ intercept[AnalysisException] {
+ catalog.getPartition("db2", "does_not_exist", part1.spec)
+ }
+ }
+
+ test("alter partitions") {
+ val catalog = newBasicCatalog()
+ val partSameSpec = part1.copy(storage = storageFormat.copy(serde = "myserde"))
+ val partNewSpec = part1.copy(spec = Map("x" -> "10"))
+ // alter but keep spec the same
+ catalog.alterPartition("db2", "tbl2", part1.spec, partSameSpec)
+ assert(catalog.getPartition("db2", "tbl2", part1.spec) == partSameSpec)
+ // alter and change spec
+ catalog.alterPartition("db2", "tbl2", part1.spec, partNewSpec)
+ intercept[AnalysisException] {
+ catalog.getPartition("db2", "tbl2", part1.spec)
+ }
+ assert(catalog.getPartition("db2", "tbl2", partNewSpec.spec) == partNewSpec)
+ }
+
+ test("alter partition when database / table does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.alterPartition("does_not_exist", "tbl1", part1.spec, part1)
+ }
+ intercept[AnalysisException] {
+ catalog.alterPartition("db2", "does_not_exist", part1.spec, part1)
+ }
+ }
+
+ // --------------------------------------------------------------------------
+ // Functions
+ // --------------------------------------------------------------------------
+
+ test("basic create and list functions") {
+ val catalog = newEmptyCatalog()
+ catalog.createDatabase(newDb("mydb"), ignoreIfExists = false)
+ catalog.createFunction("mydb", newFunc("myfunc"), ignoreIfExists = false)
+ assert(catalog.listFunctions("mydb", "*").toSet == Set("myfunc"))
+ }
+
+ test("create function when database does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.createFunction("does_not_exist", newFunc(), ignoreIfExists = false)
+ }
+ }
+
+ test("create function that already exists") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = false)
+ }
+ catalog.createFunction("db2", newFunc("func1"), ignoreIfExists = true)
+ }
+
+ test("drop function") {
+ val catalog = newBasicCatalog()
+ assert(catalog.listFunctions("db2", "*").toSet == Set("func1"))
+ catalog.dropFunction("db2", "func1")
+ assert(catalog.listFunctions("db2", "*").isEmpty)
+ }
+
+ test("drop function when database does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.dropFunction("does_not_exist", "something")
+ }
+ }
+
+ test("drop function that does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.dropFunction("db2", "does_not_exist")
+ }
+ }
+
+ test("get function") {
+ val catalog = newBasicCatalog()
+ assert(catalog.getFunction("db2", "func1") == newFunc("func1"))
+ intercept[AnalysisException] {
+ catalog.getFunction("db2", "does_not_exist")
+ }
+ }
+
+ test("get function when database does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.getFunction("does_not_exist", "func1")
+ }
+ }
+
+ test("alter function") {
+ val catalog = newBasicCatalog()
+ assert(catalog.getFunction("db2", "func1").className == funcClass)
+ // alter func but keep name
+ catalog.alterFunction("db2", "func1", newFunc("func1").copy(className = "muhaha"))
+ assert(catalog.getFunction("db2", "func1").className == "muhaha")
+ // alter func and change name
+ catalog.alterFunction("db2", "func1", newFunc("funcky"))
+ intercept[AnalysisException] {
+ catalog.getFunction("db2", "func1")
+ }
+ assert(catalog.getFunction("db2", "funcky").className == funcClass)
+ }
+
+ test("alter function when database does not exist") {
+ val catalog = newBasicCatalog()
+ intercept[AnalysisException] {
+ catalog.alterFunction("does_not_exist", "func1", newFunc())
+ }
+ }
+
+ test("list functions") {
+ val catalog = newBasicCatalog()
+ catalog.createFunction("db2", newFunc("func2"), ignoreIfExists = false)
+ catalog.createFunction("db2", newFunc("not_me"), ignoreIfExists = false)
+ assert(catalog.listFunctions("db2", "*").toSet == Set("func1", "func2", "not_me"))
+ assert(catalog.listFunctions("db2", "func*").toSet == Set("func1", "func2"))
+ }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala
new file mode 100644
index 0000000000000..871f0a0f46a22
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+/** Test suite for the [[InMemoryCatalog]]. */
+class InMemoryCatalogSuite extends CatalogTestCases {
+ override protected def newEmptyCatalog(): Catalog = new InMemoryCatalog
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index bc36a55ae0ea2..8b02b63c6cf3a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
case class StringLongClass(a: String, b: Long)
@@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int)
case class ComplexClass(a: Long, b: StringLongClass)
class EncoderResolutionSuite extends PlanTest {
+ private val str = UTF8String.fromString("hello")
+
test("real type doesn't match encoder schema but they are compatible: product") {
val encoder = ExpressionEncoder[StringLongClass]
- val cls = classOf[StringLongClass]
-
- {
- val attrs = Seq('a.string, 'b.int)
- val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
- val expected: Expression = NewInstance(
- cls,
- Seq(
- toExternalString('a.string),
- AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
- ),
- ObjectType(cls),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
- }
+ // int type can be up cast to long type
+ val attrs1 = Seq('a.string, 'b.int)
+ encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
- {
- val attrs = Seq('a.int, 'b.long)
- val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
- val expected = NewInstance(
- cls,
- Seq(
- toExternalString('a.int.cast(StringType)),
- AssertNotNull('b.long, cls.getName, "b", "Long")
- ),
- ObjectType(cls),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
- }
+ // int type can be up cast to string type
+ val attrs2 = Seq('a.int, 'b.long)
+ encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
}
test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
- val innerCls = classOf[StringLongClass]
- val cls = classOf[ComplexClass]
-
val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
- val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
- val expected: Expression = NewInstance(
- cls,
- Seq(
- AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
- If(
- 'b.struct('a.int, 'b.long).isNull,
- Literal.create(null, ObjectType(innerCls)),
- NewInstance(
- innerCls,
- Seq(
- toExternalString(
- GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
- AssertNotNull(
- GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
- innerCls.getName, "b", "Long")),
- ObjectType(innerCls),
- propagateNull = false)
- )),
- ObjectType(cls),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
+ encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
}
test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
val encoder = ExpressionEncoder.tuple(
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
- val cls = classOf[StringLongClass]
-
val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
- val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
- val expected: Expression = NewInstance(
- classOf[Tuple2[_, _]],
- Seq(
- NewInstance(
- cls,
- Seq(
- toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
- AssertNotNull(
- GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
- cls.getName, "b", "Long")),
- ObjectType(cls),
- propagateNull = false),
- 'b.int.cast(LongType)),
- ObjectType(classOf[Tuple2[_, _]]),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
+ encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+ }
+
+ test("nullability of array type element should not fail analysis") {
+ val encoder = ExpressionEncoder[Seq[Int]]
+ val attrs = 'a.array(IntegerType) :: Nil
+
+ // It should pass analysis
+ val bound = encoder.resolve(attrs, null).bind(attrs)
+
+ // If no null values appear, it should work fine
+ bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+
+ // If there is a null value, it should throw a runtime exception
+ val e = intercept[RuntimeException] {
+ bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+ }
+ assert(e.getMessage.contains("Null value appeared in non-nullable field"))
}
test("the real number of fields doesn't match encoder schema: tuple encoder") {
@@ -127,7 +84,7 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string, 'b.long, 'c.int)
- assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+ assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
"Try to map struct to Tuple2, " +
"but failed as the number of fields does not line up.\n" +
" - Input schema: struct\n" +
@@ -136,7 +93,7 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string)
- assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+ assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
"Try to map struct to Tuple2, " +
"but failed as the number of fields does not line up.\n" +
" - Input schema: struct\n" +
@@ -149,7 +106,7 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
- assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+ assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
"Try to map struct to Tuple2, " +
"but failed as the number of fields does not line up.\n" +
" - Input schema: struct>\n" +
@@ -158,7 +115,7 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string, 'b.struct('x.long))
- assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+ assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
"Try to map struct to Tuple2, " +
"but failed as the number of fields does not line up.\n" +
" - Input schema: struct>\n" +
@@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest {
}
}
- private def toExternalString(e: Expression): Expression = {
- Invoke(e, "toString", ObjectType(classOf[String]), Nil)
- }
-
test("throw exception if real type is not compatible with encoder schema") {
val msg1 = intercept[AnalysisException] {
ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 88c558d80a79a..e00060f9b6aff 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -19,13 +19,10 @@ package org.apache.spark.sql.catalyst.encoders
import java.sql.{Date, Timestamp}
import java.util.Arrays
-import java.util.concurrent.ConcurrentMap
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
-import com.google.common.collect.MapMaker
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
@@ -78,7 +75,7 @@ class JavaSerializable(val value: Int) extends Serializable {
}
class ExpressionEncoderSuite extends SparkFunSuite {
- OuterScopes.outerScopes.put(getClass.getName, this)
+ OuterScopes.addOuterScope(this)
implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 0c42e2fc7c5e5..b5413fbe2bbcc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -36,7 +36,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
import scala.concurrent.duration._
val futures = (1 to 20).map { _ =>
- future {
+ Future {
GeneratePredicate.generate(EqualTo(Literal(1), Literal(1)))
GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil)
GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
index 37148a226f293..a4a12c0d62e92 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
@@ -28,21 +28,9 @@ class AggregateOptimizeSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Aggregate", FixedPoint(100),
- ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) :: Nil
}
- test("replace distinct with aggregate") {
- val input = LocalRelation('a.int, 'b.int)
-
- val query = Distinct(input)
- val optimized = Optimize.execute(query.analyze)
-
- val correctAnswer = Aggregate(input.output, input.output, input)
-
- comparePlans(optimized, correctAnswer)
- }
-
test("remove literals in grouping expression") {
val input = LocalRelation('a.int, 'b.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
similarity index 96%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 85b6530481b03..f5fd5ca6beb15 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -25,11 +25,11 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-class ProjectCollapsingSuite extends PlanTest {
+class CollapseProjectSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
- Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil
+ Batch("CollapseProject", Once, CollapseProject) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index f9f3bd55aa578..b49ca928b6292 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -42,7 +42,7 @@ class FilterPushdownSuite extends PlanTest {
PushPredicateThroughGenerate,
PushPredicateThroughAggregate,
ColumnPruning,
- ProjectCollapsing) :: Nil
+ CollapseProject) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
index 9b1e16c727647..858a0d8fde3ea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
@@ -43,7 +43,7 @@ class JoinOrderSuite extends PlanTest {
PushPredicateThroughGenerate,
PushPredicateThroughAggregate,
ColumnPruning,
- ProjectCollapsing) :: Nil
+ CollapseProject) :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
new file mode 100644
index 0000000000000..f8ae5d9be2084
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class ReplaceOperatorSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Replace Operators", FixedPoint(100),
+ ReplaceDistinctWithAggregate,
+ ReplaceIntersectWithSemiJoin) :: Nil
+ }
+
+ test("replace Intersect with Left-semi Join") {
+ val table1 = LocalRelation('a.int, 'b.int)
+ val table2 = LocalRelation('c.int, 'd.int)
+
+ val query = Intersect(table1, table2)
+ val optimized = Optimize.execute(query.analyze)
+
+ val correctAnswer =
+ Aggregate(table1.output, table1.output,
+ Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("replace Distinct with Aggregate") {
+ val input = LocalRelation('a.int, 'b.int)
+
+ val query = Distinct(input)
+ val optimized = Optimize.execute(query.analyze)
+
+ val correctAnswer = Aggregate(input.output, input.output, input)
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index 2283f7c008ba2..b8ea32b4dfe01 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -39,7 +39,6 @@ class SetOperationSuite extends PlanTest {
val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int)
val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil)
- val testIntersect = Intersect(testRelation, testRelation2)
val testExcept = Except(testRelation, testRelation2)
test("union: combine unions into one unions") {
@@ -57,19 +56,12 @@ class SetOperationSuite extends PlanTest {
comparePlans(combinedUnionsOptimized, unionOptimized3)
}
- test("intersect/except: filter to each side") {
- val intersectQuery = testIntersect.where('b < 10)
+ test("except: filter to each side") {
val exceptQuery = testExcept.where('c >= 5)
-
- val intersectOptimized = Optimize.execute(intersectQuery.analyze)
val exceptOptimized = Optimize.execute(exceptQuery.analyze)
-
- val intersectCorrectAnswer =
- Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
val exceptCorrectAnswer =
Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze
- comparePlans(intersectOptimized, intersectCorrectAnswer)
comparePlans(exceptOptimized, exceptCorrectAnswer)
}
@@ -95,13 +87,8 @@ class SetOperationSuite extends PlanTest {
}
test("SPARK-10539: Project should not be pushed down through Intersect or Except") {
- val intersectQuery = testIntersect.select('b, 'c)
val exceptQuery = testExcept.select('a, 'b, 'c)
-
- val intersectOptimized = Optimize.execute(intersectQuery.analyze)
val exceptOptimized = Optimize.execute(exceptQuery.analyze)
-
- comparePlans(intersectOptimized, intersectQuery.analyze)
comparePlans(exceptOptimized, exceptQuery.analyze)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala
new file mode 100644
index 0000000000000..8b05f9e33d69e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.SparkFunSuite
+
+class ASTNodeSuite extends SparkFunSuite {
+ test("SPARK-13157 - remainder must return all input chars") {
+ val inputs = Seq(
+ ("add jar", "file:///tmp/ab/TestUDTF.jar"),
+ ("add jar", "file:///tmp/a@b/TestUDTF.jar"),
+ ("add jar", "c:\\windows32\\TestUDTF.jar"),
+ ("add jar", "some \nbad\t\tfile\r\n.\njar"),
+ ("ADD JAR", "@*#&@(!#@$^*!@^@#(*!@#"),
+ ("SET", "foo=bar"),
+ ("SET", "foo*)(@#^*@&!#^=bar")
+ )
+ inputs.foreach {
+ case (command, arguments) =>
+ val node = ParseDriver.parsePlan(s"$command $arguments", null)
+ assert(node.remainder === arguments)
+ }
+ }
+}
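
The property asserted above is simply that everything after the command keyword survives verbatim as `remainder`, no matter which special characters it contains. A toy illustration of that invariant (not the actual parser code; `remainderOf` is a hypothetical helper):

```scala
// Hypothetical helper: strip the leading command keyword and one separating space.
def remainderOf(sql: String, command: String): String =
  sql.stripPrefix(command).dropWhile(_ == ' ')

assert(remainderOf("add jar file:///tmp/a@b/TestUDTF.jar", "add jar") == "file:///tmp/a@b/TestUDTF.jar")
assert(remainderOf("SET foo*)(@#^*@&!#^=bar", "SET") == "foo*)(@#^*@&!#^=bar")
```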
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
new file mode 100644
index 0000000000000..b5cf91394d910
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+
+class ConstraintPropagationSuite extends SparkFunSuite {
+
+ private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
+ tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
+
+ private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = {
+ val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _))
+ val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _))
+ if (missing.nonEmpty || extra.nonEmpty) {
+ fail(
+ s"""
+ |== FAIL: Constraints do not match ===
+ |Found: ${found.mkString(",")}
+ |Expected: ${expected.mkString(",")}
+ |== Result ==
+ |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")}
+ |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")}
+ """.stripMargin)
+ }
+ }
+
+ test("propagating constraints in filters") {
+ val tr = LocalRelation('a.int, 'b.string, 'c.int)
+
+ assert(tr.analyze.constraints.isEmpty)
+
+ assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)
+
+ verifyConstraints(tr
+ .where('a.attr > 10)
+ .analyze.constraints,
+ Set(resolveColumn(tr, "a") > 10,
+ IsNotNull(resolveColumn(tr, "a"))))
+
+ verifyConstraints(tr
+ .where('a.attr > 10)
+ .select('c.attr, 'a.attr)
+ .where('c.attr < 100)
+ .analyze.constraints,
+ Set(resolveColumn(tr, "a") > 10,
+ resolveColumn(tr, "c") < 100,
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "c"))))
+ }
+
+ test("propagating constraints in union") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
+ val tr2 = LocalRelation('d.int, 'e.int, 'f.int)
+ val tr3 = LocalRelation('g.int, 'h.int, 'i.int)
+
+ assert(tr1
+ .where('a.attr > 10)
+ .unionAll(tr2.where('e.attr > 10)
+ .unionAll(tr3.where('i.attr > 10)))
+ .analyze.constraints.isEmpty)
+
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .unionAll(tr2.where('d.attr > 10)
+ .unionAll(tr3.where('g.attr > 10)))
+ .analyze.constraints,
+ Set(resolveColumn(tr1, "a") > 10,
+ IsNotNull(resolveColumn(tr1, "a"))))
+ }
+
+ test("propagating constraints in intersect") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
+ val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
+
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .intersect(tr2.where('b.attr < 100))
+ .analyze.constraints,
+ Set(resolveColumn(tr1, "a") > 10,
+ resolveColumn(tr1, "b") < 100,
+ IsNotNull(resolveColumn(tr1, "a")),
+ IsNotNull(resolveColumn(tr1, "b"))))
+ }
+
+ test("propagating constraints in except") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
+ val tr2 = LocalRelation('a.int, 'b.int, 'c.int)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .except(tr2.where('b.attr < 100))
+ .analyze.constraints,
+ Set(resolveColumn(tr1, "a") > 10,
+ IsNotNull(resolveColumn(tr1, "a"))))
+ }
+
+ test("propagating constraints in inner join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints,
+ Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
+ tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
+ tr1.resolveQuoted("a", caseInsensitiveResolution).get ===
+ tr2.resolveQuoted("a", caseInsensitiveResolution).get,
+ IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
+ IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
+ IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
+ }
+
+ test("propagating constraints in left-semi join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints,
+ Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
+ IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
+ }
+
+ test("propagating constraints in left-outer join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints,
+ Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
+ IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
+ }
+
+ test("propagating constraints in right-outer join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints,
+ Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
+ IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
+ }
+
+ test("propagating constraints in full-outer join") {
+ val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1)
+ val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2)
+ assert(tr1.where('a.attr > 10)
+ .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr))
+ .analyze.constraints.isEmpty)
+ }
+}
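
For readers skimming the suite, the propagation rules it pins down can be summarized with a toy model (plain strings instead of resolved expressions; attribute rewriting across children is omitted):

```scala
// Toy model of how constraints combine per operator in the tests above.
sealed trait Node { def constraints: Set[String] }
case class Leaf(constraints: Set[String]) extends Node
case class Filter(cond: String, child: Node) extends Node {
  def constraints: Set[String] = child.constraints + cond   // filters add their predicate
}
case class Union(children: Seq[Node]) extends Node {
  // only constraints shared by every child survive (the real code also rewrites attributes)
  def constraints: Set[String] = children.map(_.constraints).reduce(_ intersect _)
}
case class Intersect(left: Node, right: Node) extends Node {
  def constraints: Set[String] = left.constraints ++ right.constraints   // both sides apply
}
case class Except(left: Node, right: Node) extends Node {
  def constraints: Set[String] = left.constraints                        // only the left side
}
// Joins (not modeled here): inner keeps both sides plus the join condition, left-semi and
// left-outer keep the left side, right-outer keeps the right side, full-outer keeps nothing.
```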
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 706ecd29d1355..c2bbca7c33f28 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -122,7 +122,9 @@ class DataTypeSuite extends SparkFunSuite {
val right = StructType(List())
val merged = left.merge(right)
- assert(merged === left)
+ assert(DataType.equalsIgnoreCompatibleNullability(merged, left))
+ assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField))
}
test("merge where left is empty") {
@@ -135,8 +137,9 @@ class DataTypeSuite extends SparkFunSuite {
val merged = left.merge(right)
- assert(right === merged)
-
+ assert(DataType.equalsIgnoreCompatibleNullability(merged, right))
+ assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField))
}
test("merge where both are non-empty") {
@@ -154,7 +157,10 @@ class DataTypeSuite extends SparkFunSuite {
val merged = left.merge(right)
- assert(merged === expected)
+ assert(DataType.equalsIgnoreCompatibleNullability(merged, expected))
+ assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(merged("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
}
test("merge where right contains type conflict") {
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 4bb55f6b7f739..89e01fc01596e 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -21,13 +21,13 @@
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
    <version>2.0.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
  </parent>
  <groupId>org.apache.spark</groupId>
-  <artifactId>spark-sql_2.10</artifactId>
+  <artifactId>spark-sql_2.11</artifactId>
  <packaging>jar</packaging>
  <name>Spark Project SQL</name>
  <url>http://spark.apache.org/</url>
@@ -44,7 +44,7 @@
    <dependency>
      <groupId>org.apache.spark</groupId>
-      <artifactId>spark-sketch_2.10</artifactId>
+      <artifactId>spark-sketch_2.11</artifactId>
      <version>${project.version}</version>
    </dependency>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index 17adfec32192f..b5dddb9f11b22 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -21,6 +21,7 @@
import java.nio.ByteBuffer;
import java.util.List;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.parquet.Preconditions;
@@ -41,6 +42,7 @@
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
+import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;
@@ -207,13 +209,7 @@ public boolean nextBatch() throws IOException {
int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned);
for (int i = 0; i < columnReaders.length; ++i) {
- switch (columnReaders[i].descriptor.getType()) {
- case INT32:
- columnReaders[i].readIntBatch(num, columnarBatch.column(i));
- break;
- default:
- throw new IOException("Unsupported type: " + columnReaders[i].descriptor.getType());
- }
+ columnReaders[i].readBatch(num, columnarBatch.column(i));
}
rowsReturned += num;
columnarBatch.setNumRows(num);
@@ -237,7 +233,8 @@ private void initializeInternal() throws IOException {
// TODO: Be extremely cautious in what is supported. Expand this.
if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL &&
- originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) {
+ originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE &&
+ originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) {
throw new IOException("Unsupported type: " + t);
}
if (originalTypes[i] == OriginalType.DECIMAL &&
@@ -464,6 +461,11 @@ private final class ColumnReader {
*/
private boolean useDictionary;
+ /**
+ * If useDictionary is true, the staging vector used to decode the ids.
+ */
+ private ColumnVector dictionaryIds;
+
/**
* Maximum definition level for this column.
*/
@@ -587,9 +589,8 @@ private boolean next() throws IOException {
/**
* Reads `total` values from this columnReader into column.
- * TODO: implement the other encodings.
*/
- private void readIntBatch(int total, ColumnVector column) throws IOException {
+ private void readBatch(int total, ColumnVector column) throws IOException {
int rowId = 0;
while (total > 0) {
// Compute the number of values we want to read in this page.
@@ -599,21 +600,134 @@ private void readIntBatch(int total, ColumnVector column) throws IOException {
leftInPage = (int)(endOfPageValueCount - valuesRead);
}
int num = Math.min(total, leftInPage);
- defColumn.readIntegers(
- num, column, rowId, maxDefLevel, (VectorizedValuesReader)dataColumn, 0);
-
- // Remap the values if it is dictionary encoded.
if (useDictionary) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putInt(i, dictionary.decodeToInt(column.getInt(i)));
+ // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
+ if (dictionaryIds == null) {
+ dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP);
+ } else {
+ dictionaryIds.reset();
+ dictionaryIds.reserve(total);
+ }
+ // Read and decode dictionary ids.
+ readIntBatch(rowId, num, dictionaryIds);
+ decodeDictionaryIds(rowId, num, column);
+ } else {
+ switch (descriptor.getType()) {
+ case INT32:
+ readIntBatch(rowId, num, column);
+ break;
+ case INT64:
+ readLongBatch(rowId, num, column);
+ break;
+ case BINARY:
+ readBinaryBatch(rowId, num, column);
+ break;
+ default:
+ throw new IOException("Unsupported type: " + descriptor.getType());
}
}
+
valuesRead += num;
rowId += num;
total -= num;
}
}
+ /**
+ * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
+ */
+ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
+ switch (descriptor.getType()) {
+ case INT32:
+ if (column.dataType() == DataTypes.IntegerType) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else if (column.dataType() == DataTypes.ByteType) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ break;
+
+ case INT64:
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
+ }
+ break;
+
+ case BINARY:
+ // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
+ // need to do this better. We should probably add the dictionary data to the ColumnVector
+ // and reuse it across batches. This should mean adding a ByteArray would just update
+ // the length and offset.
+ for (int i = rowId; i < rowId + num; ++i) {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+ column.putByteArray(i, v.getBytes());
+ }
+ break;
+
+ default:
+ throw new NotImplementedException("Unsupported type: " + descriptor.getType());
+ }
+
+ if (dictionaryIds.numNulls() > 0) {
+ // Copy the NULLs over.
+ // TODO: we can improve this by decoding the NULLs directly into column. This would
+ // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then
+ // just do the ID remapping as above.
+ for (int i = 0; i < num; ++i) {
+ if (dictionaryIds.getIsNull(rowId + i)) {
+ column.putNull(rowId + i);
+ }
+ }
+ }
+ }
+
+ /**
+ * For all the read*Batch functions, reads `num` values from this columnReader into column. It
+ * is guaranteed that num is no larger than the number of values left in the current page.
+ */
+
+ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (column.dataType() == DataTypes.IntegerType) {
+ defColumn.readIntegers(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0);
+ } else if (column.dataType() == DataTypes.ByteType) {
+ defColumn.readBytes(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ }
+
+ private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (column.dataType() == DataTypes.LongType) {
+ defColumn.readLongs(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ }
+
+ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
+ // This is where we implement support for the valid type conversions.
+ // TODO: implement remaining type conversions
+ if (column.isArray()) {
+ defColumn.readBinarys(
+ num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ }
+
+
private void readPage() throws IOException {
DataPage page = pageReader.readPage();
// TODO: Why is this a visitor?
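
The dictionary path introduced above is a two-phase decode: the ids are vector-decoded into a reusable staging IntegerType vector (`dictionaryIds`), and a per-type loop then maps ids to values. A stripped-down sketch of the second phase over plain arrays (illustrative names, not the reader's API):

```scala
// Illustrative only: decode a batch of dictionary ids into an INT64 output column.
def decodeDictionaryLongs(ids: Array[Int], dictionary: Array[Long], out: Array[Long]): Unit = {
  var i = 0
  while (i < ids.length) {
    out(i) = dictionary(ids(i))   // id -> dictionary value
    i += 1
  }
}
```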
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index dac0c52ebd2cf..cec2418e46030 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -18,10 +18,13 @@
import java.io.IOException;
+import org.apache.spark.sql.Column;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.unsafe.Platform;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.parquet.column.values.ValuesReader;
+import org.apache.parquet.io.api.Binary;
/**
* An implementation of the Parquet PLAIN decoder that supports the vectorized interface.
@@ -52,15 +55,53 @@ public void skip(int n) {
}
@Override
- public void readIntegers(int total, ColumnVector c, int rowId) {
+ public final void readIntegers(int total, ColumnVector c, int rowId) {
c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 4 * total;
}
@Override
- public int readInteger() {
+ public final void readLongs(int total, ColumnVector c, int rowId) {
+ c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
+ offset += 8 * total;
+ }
+
+ @Override
+ public final void readBytes(int total, ColumnVector c, int rowId) {
+ for (int i = 0; i < total; i++) {
+ // Bytes are stored as a 4-byte little endian int. Just read the first byte.
+ // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
+ c.putByte(rowId + i, buffer[offset]);
+ offset += 4;
+ }
+ }
+
+ @Override
+ public final int readInteger() {
int v = Platform.getInt(buffer, offset);
offset += 4;
return v;
}
+
+ @Override
+ public final long readLong() {
+ long v = Platform.getLong(buffer, offset);
+ offset += 8;
+ return v;
+ }
+
+ @Override
+ public final byte readByte() {
+ return (byte)readInteger();
+ }
+
+ @Override
+ public final void readBinary(int total, ColumnVector v, int rowId) {
+ for (int i = 0; i < total; i++) {
+ int len = readInteger();
+ int start = offset;
+ offset += len;
+ v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len);
+ }
+ }
}
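
The PLAIN decoder is a cursor over the page bytes: fixed-width values are little-endian and read at a moving offset, and binaries are a 4-byte length followed by that many bytes. A self-contained sketch of the same layout using ByteBuffer (the real reader uses Platform.get* on the raw array):

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Illustrative decoder over a Parquet PLAIN-encoded page buffer.
final class PlainDecoderSketch(page: Array[Byte]) {
  private val buf = ByteBuffer.wrap(page).order(ByteOrder.LITTLE_ENDIAN)
  def readInt(): Int = buf.getInt()
  def readLong(): Long = buf.getLong()
  def readBinary(): Array[Byte] = {
    val len = buf.getInt()          // 4-byte little-endian length prefix
    val out = new Array[Byte](len)
    buf.get(out)
    out
  }
}
```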
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 493ec9deed499..9bfd74db38766 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -17,12 +17,16 @@
package org.apache.spark.sql.execution.datasources.parquet;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.parquet.Preconditions;
import org.apache.parquet.bytes.BytesUtils;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.column.values.bitpacking.BytePacker;
import org.apache.parquet.column.values.bitpacking.Packer;
import org.apache.parquet.io.ParquetDecodingException;
+import org.apache.parquet.io.api.Binary;
+
+import org.apache.spark.sql.Column;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
/**
@@ -35,7 +39,8 @@
* - Definition/Repetition levels
* - Dictionary ids.
*/
-public final class VectorizedRleValuesReader extends ValuesReader {
+public final class VectorizedRleValuesReader extends ValuesReader
+ implements VectorizedValuesReader {
// Current decoding mode. The encoded data contains groups of either run length encoded data
// (RLE) or bit packed data. Each group contains a header that indicates which group it is and
// the number of values in the group.
@@ -121,6 +126,7 @@ public int readValueDictionaryId() {
return readInteger();
}
+
@Override
public int readInteger() {
if (this.currentCount == 0) { this.readNextGroup(); }
@@ -138,7 +144,9 @@ public int readInteger() {
/**
* Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader
* reads the definition levels and then will read from `data` for the non-null values.
- * If the value is null, c will be populated with `nullValue`.
+ * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only
+ * necessary for readIntegers because we also use it to decode dictionaryIds and want to make
+ * sure it always has a value in range.
*
* This is a batched version of this logic:
* if (this.readInt() == level) {
@@ -180,6 +188,154 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level,
}
}
+ // TODO: can this code duplication be removed without a perf penalty?
+ public void readBytes(int total, ColumnVector c,
+ int rowId, int level, VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readBytes(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putByte(rowId + i, data.readByte());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ public void readLongs(int total, ColumnVector c, int rowId, int level,
+ VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ data.readLongs(n, c, rowId);
+ c.putNotNulls(rowId, n);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putLong(rowId + i, data.readLong());
+ c.putNotNull(rowId + i);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ public void readBinarys(int total, ColumnVector c, int rowId, int level,
+ VectorizedValuesReader data) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == level) {
+ c.putNotNulls(rowId, n);
+ data.readBinary(n, c, rowId);
+ } else {
+ c.putNulls(rowId, n);
+ }
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == level) {
+ c.putNotNull(rowId + i);
+ data.readBinary(1, c, rowId);
+ } else {
+ c.putNull(rowId + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+
+ // The RLE reader implements the vectorized decoding interface when used to decode dictionary
+ // IDs. This is different from the APIs above, which decode definition levels along with values.
+ // Since this is only used to decode dictionary IDs, only decoding integers is supported.
+ @Override
+ public void readIntegers(int total, ColumnVector c, int rowId) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ switch (mode) {
+ case RLE:
+ c.putInts(rowId, n, currentValue);
+ break;
+ case PACKED:
+ c.putInts(rowId, n, currentBuffer, currentBufferIdx);
+ currentBufferIdx += n;
+ break;
+ }
+ rowId += n;
+ left -= n;
+ currentCount -= n;
+ }
+ }
+
+ @Override
+ public byte readByte() {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readBytes(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readLongs(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void readBinary(int total, ColumnVector c, int rowId) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+ @Override
+ public void skip(int n) {
+ throw new UnsupportedOperationException("only readInts is valid.");
+ }
+
+
/**
* Reads the next varint encoded int.
*/
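
The hybrid format consumed by readIntegers alternates between run-length groups (one value repeated currentCount times) and bit-packed groups of literal values. A toy model of the consumption loop, ignoring the partial-group state the real reader carries across calls:

```scala
sealed trait Group
case class Run(value: Int, count: Int) extends Group   // RLE group
case class Packed(values: Array[Int]) extends Group    // bit-packed group, already unpacked

def readIds(groups: Iterator[Group], total: Int): Array[Int] = {
  val out = new Array[Int](total)
  var rowId = 0
  while (rowId < total) {
    groups.next() match {
      case Run(v, count) =>
        val n = math.min(count, total - rowId)
        java.util.Arrays.fill(out, rowId, rowId + n, v)
        rowId += n
      case Packed(vs) =>
        val n = math.min(vs.length, total - rowId)
        System.arraycopy(vs, 0, out, rowId, n)
        rowId += n
    }
  }
  out
}
```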
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
index 49a9ed83d590a..b6ec7311c564a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
@@ -24,12 +24,17 @@
* TODO: merge this into parquet-mr.
*/
public interface VectorizedValuesReader {
+ byte readByte();
int readInteger();
+ long readLong();
/*
* Reads `total` values into `c` start at `c[rowId]`
*/
+ void readBytes(int total, ColumnVector c, int rowId);
void readIntegers(int total, ColumnVector c, int rowId);
+ void readLongs(int total, ColumnVector c, int rowId);
+ void readBinary(int total, ColumnVector c, int rowId);
// TODO: add all the other parquet types.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index c119758d68b36..0514252a8e53d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -16,6 +16,9 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
@@ -102,18 +105,36 @@ public Object[] array() {
DataType dt = data.dataType();
Object[] list = new Object[length];
- if (dt instanceof ByteType) {
+ if (dt instanceof BooleanType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = data.getBoolean(offset + i);
+ }
+ }
+ } else if (dt instanceof ByteType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getByte(offset + i);
}
}
+ } else if (dt instanceof ShortType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = data.getShort(offset + i);
+ }
+ }
} else if (dt instanceof IntegerType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getInt(offset + i);
}
}
+ } else if (dt instanceof FloatType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = data.getFloat(offset + i);
+ }
+ }
} else if (dt instanceof DoubleType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
@@ -126,12 +147,25 @@ public Object[] array() {
list[i] = data.getLong(offset + i);
}
}
+ } else if (dt instanceof DecimalType) {
+ DecimalType decType = (DecimalType)dt;
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = getDecimal(i, decType.precision(), decType.scale());
+ }
+ }
} else if (dt instanceof StringType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i));
}
}
+ } else if (dt instanceof CalendarIntervalType) {
+ for (int i = 0; i < length; i++) {
+ if (!data.getIsNull(offset + i)) {
+ list[i] = getInterval(i);
+ }
+ }
} else {
throw new NotImplementedException("Type " + dt);
}
@@ -170,7 +204,14 @@ public float getFloat(int ordinal) {
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
- throw new NotImplementedException();
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.apply(getLong(ordinal), precision, scale);
+ } else {
+ byte[] bytes = getBinary(ordinal);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(javaDecimal, precision, scale);
+ }
}
@Override
@@ -181,17 +222,22 @@ public UTF8String getUTF8String(int ordinal) {
@Override
public byte[] getBinary(int ordinal) {
- throw new NotImplementedException();
+ ColumnVector.Array array = data.getByteArray(offset + ordinal);
+ byte[] bytes = new byte[array.length];
+ System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+ return bytes;
}
@Override
public CalendarInterval getInterval(int ordinal) {
- throw new NotImplementedException();
+ int month = data.getChildColumn(0).getInt(offset + ordinal);
+ long microseconds = data.getChildColumn(1).getLong(offset + ordinal);
+ return new CalendarInterval(month, microseconds);
}
@Override
public InternalRow getStruct(int ordinal, int numFields) {
- throw new NotImplementedException();
+ return data.getStruct(offset + ordinal);
}
@Override
@@ -210,104 +256,6 @@ public Object get(int ordinal, DataType dataType) {
}
}
- /**
- * Holder object to return a struct. This object is intended to be reused.
- */
- public static final class Struct extends InternalRow {
- // The fields that make up this struct. For example, if the struct had 2 int fields, the access
- // to it would be:
- // int f1 = fields[0].getInt[rowId]
- // int f2 = fields[1].getInt[rowId]
- public final ColumnVector[] fields;
-
- @Override
- public boolean isNullAt(int fieldIdx) { return fields[fieldIdx].getIsNull(rowId); }
-
- @Override
- public boolean getBoolean(int ordinal) {
- throw new NotImplementedException();
- }
-
- public byte getByte(int fieldIdx) { return fields[fieldIdx].getByte(rowId); }
-
- @Override
- public short getShort(int ordinal) {
- throw new NotImplementedException();
- }
-
- public int getInt(int fieldIdx) { return fields[fieldIdx].getInt(rowId); }
- public long getLong(int fieldIdx) { return fields[fieldIdx].getLong(rowId); }
-
- @Override
- public float getFloat(int ordinal) {
- throw new NotImplementedException();
- }
-
- public double getDouble(int fieldIdx) { return fields[fieldIdx].getDouble(rowId); }
-
- @Override
- public Decimal getDecimal(int ordinal, int precision, int scale) {
- throw new NotImplementedException();
- }
-
- @Override
- public UTF8String getUTF8String(int ordinal) {
- Array a = getByteArray(ordinal);
- return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
- }
-
- @Override
- public byte[] getBinary(int ordinal) {
- throw new NotImplementedException();
- }
-
- @Override
- public CalendarInterval getInterval(int ordinal) {
- throw new NotImplementedException();
- }
-
- @Override
- public InternalRow getStruct(int ordinal, int numFields) {
- return fields[ordinal].getStruct(rowId);
- }
-
- public Array getArray(int fieldIdx) { return fields[fieldIdx].getArray(rowId); }
-
- @Override
- public MapData getMap(int ordinal) {
- throw new NotImplementedException();
- }
-
- @Override
- public Object get(int ordinal, DataType dataType) {
- throw new NotImplementedException();
- }
-
- public Array getByteArray(int fieldIdx) { return fields[fieldIdx].getByteArray(rowId); }
- public Struct getStruct(int fieldIdx) { return fields[fieldIdx].getStruct(rowId); }
-
- @Override
- public final int numFields() {
- return fields.length;
- }
-
- @Override
- public InternalRow copy() {
- throw new NotImplementedException();
- }
-
- @Override
- public boolean anyNull() {
- throw new NotImplementedException();
- }
-
- protected int rowId;
-
- protected Struct(ColumnVector[] fields) {
- this.fields = fields;
- }
- }
-
/**
* Returns the data type of this column.
*/
@@ -377,6 +325,21 @@ public void reset() {
*/
public abstract boolean getIsNull(int rowId);
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putBoolean(int rowId, boolean value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putBooleans(int rowId, int count, boolean value);
+
+ /**
+ * Returns the value for rowId.
+ */
+ public abstract boolean getBoolean(int rowId);
+
/**
* Sets the value at rowId to `value`.
*/
@@ -397,6 +360,26 @@ public void reset() {
*/
public abstract byte getByte(int rowId);
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putShort(int rowId, short value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putShorts(int rowId, int count, short value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ */
+ public abstract void putShorts(int rowId, int count, short[] src, int srcIndex);
+
+ /**
+ * Returns the value for rowId.
+ */
+ public abstract short getShort(int rowId);
+
/**
* Sets the value at rowId to `value`.
*/
@@ -449,6 +432,33 @@ public void reset() {
*/
public abstract long getLong(int rowId);
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putFloat(int rowId, float value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putFloats(int rowId, int count, float value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ * src should contain `count` floats written in IEEE format.
+ */
+ public abstract void putFloats(int rowId, int count, float[] src, int srcIndex);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
+ * The data in src must be ieee formatted floats.
+ */
+ public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex);
+
+ /**
+ * Returns the value for rowId.
+ */
+ public abstract float getFloat(int rowId);
+
/**
* Sets the value at rowId to `value`.
*/
@@ -467,7 +477,7 @@ public void reset() {
/**
* Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
- * The data in src must be ieee formated doubles.
+ * The data in src must be ieee formatted doubles.
*/
public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex);
@@ -494,7 +504,7 @@ public void reset() {
/**
* Returns a utility object to get structs.
*/
- public Struct getStruct(int rowId) {
+ public ColumnarBatch.Row getStruct(int rowId) {
resultStruct.rowId = rowId;
return resultStruct;
}
@@ -567,6 +577,20 @@ public final int appendNotNulls(int count) {
return result;
}
+ public final int appendBoolean(boolean v) {
+ reserve(elementsAppended + 1);
+ putBoolean(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendBooleans(int count, boolean v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putBooleans(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
public final int appendByte(byte v) {
reserve(elementsAppended + 1);
putByte(elementsAppended, v);
@@ -589,6 +613,28 @@ public final int appendBytes(int length, byte[] src, int offset) {
return result;
}
+ public final int appendShort(short v) {
+ reserve(elementsAppended + 1);
+ putShort(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendShorts(int count, short v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putShorts(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendShorts(int length, short[] src, int offset) {
+ reserve(elementsAppended + length);
+ int result = elementsAppended;
+ putShorts(elementsAppended, length, src, offset);
+ elementsAppended += length;
+ return result;
+ }
+
public final int appendInt(int v) {
reserve(elementsAppended + 1);
putInt(elementsAppended, v);
@@ -633,6 +679,20 @@ public final int appendLongs(int length, long[] src, int offset) {
return result;
}
+ public final int appendFloat(float v) {
+ reserve(elementsAppended + 1);
+ putFloat(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendFloats(int count, float v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putFloats(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
public final int appendDouble(double v) {
reserve(elementsAppended + 1);
putDouble(elementsAppended, v);
@@ -703,7 +763,12 @@ public final int appendStruct(boolean isNull) {
/**
* Returns the elements appended.
*/
- public int getElementsAppended() { return elementsAppended; }
+ public final int getElementsAppended() { return elementsAppended; }
+
+ /**
+ * Returns true if this column is an array.
+ */
+ public final boolean isArray() { return resultArray != null; }
/**
* Maximum number of rows that can be stored in this column.
@@ -749,7 +814,7 @@ public final int appendStruct(boolean isNull) {
/**
* Reusable Struct holder for getStruct().
*/
- protected final Struct resultStruct;
+ protected final ColumnarBatch.Row resultStruct;
/**
* Sets up the common state and also handles creating the child columns if this is a nested
@@ -759,7 +824,8 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) {
this.capacity = capacity;
this.type = type;
- if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) {
+ if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType
+ || DecimalType.isByteArrayDecimalType(type)) {
DataType childType;
int childCapacity = capacity;
if (type instanceof ArrayType) {
@@ -779,7 +845,14 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) {
this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode);
}
this.resultArray = null;
- this.resultStruct = new Struct(this.childColumns);
+ this.resultStruct = new ColumnarBatch.Row(this.childColumns);
+ } else if (type instanceof CalendarIntervalType) {
+ // Two columns. Months as int. Microseconds as Long.
+ this.childColumns = new ColumnVector[2];
+ this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode);
+ this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode);
+ this.resultArray = null;
+ this.resultStruct = new ColumnarBatch.Row(this.childColumns);
} else {
this.childColumns = null;
this.resultArray = null;
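
The new CalendarIntervalType handling stores an interval column as two child vectors, months (int) and microseconds (long); getInterval simply zips them back together, as the accessor changes above show. A minimal sketch of that layout (hypothetical class name):

```scala
import org.apache.spark.unsafe.types.CalendarInterval

// Illustrative two-array layout for an interval column.
final class IntervalColumnSketch(months: Array[Int], micros: Array[Long]) {
  def getInterval(rowId: Int): CalendarInterval =
    new CalendarInterval(months(rowId), micros(rowId))
}
```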
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 6c651a759d250..453bc15e13503 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -16,12 +16,15 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
+import java.math.BigInteger;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.commons.lang.NotImplementedException;
@@ -59,19 +62,44 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) {
private static void appendValue(ColumnVector dst, DataType t, Object o) {
if (o == null) {
- dst.appendNull();
+ if (t instanceof CalendarIntervalType) {
+ dst.appendStruct(true);
+ } else {
+ dst.appendNull();
+ }
} else {
- if (t == DataTypes.ByteType) {
- dst.appendByte(((Byte)o).byteValue());
+ if (t == DataTypes.BooleanType) {
+ dst.appendBoolean(((Boolean)o).booleanValue());
+ } else if (t == DataTypes.ByteType) {
+ dst.appendByte(((Byte) o).byteValue());
+ } else if (t == DataTypes.ShortType) {
+ dst.appendShort(((Short)o).shortValue());
} else if (t == DataTypes.IntegerType) {
dst.appendInt(((Integer)o).intValue());
} else if (t == DataTypes.LongType) {
dst.appendLong(((Long)o).longValue());
+ } else if (t == DataTypes.FloatType) {
+ dst.appendFloat(((Float)o).floatValue());
} else if (t == DataTypes.DoubleType) {
dst.appendDouble(((Double)o).doubleValue());
} else if (t == DataTypes.StringType) {
byte[] b =((String)o).getBytes();
dst.appendByteArray(b, 0, b.length);
+ } else if (t instanceof DecimalType) {
+ DecimalType dt = (DecimalType)t;
+ Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale());
+ if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
+ dst.appendLong(d.toUnscaledLong());
+ } else {
+ final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
+ byte[] bytes = integer.toByteArray();
+ dst.appendByteArray(bytes, 0, bytes.length);
+ }
+ } else if (t instanceof CalendarIntervalType) {
+ CalendarInterval c = (CalendarInterval)o;
+ dst.appendStruct(false);
+ dst.getChildColumn(0).appendInt(c.months);
+ dst.getChildColumn(1).appendLong(c.microseconds);
} else {
throw new NotImplementedException("Type " + t);
}
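
The decimal handling above picks a representation by precision: up to Decimal.MAX_LONG_DIGITS (18, if I recall correctly) the unscaled value fits in a long column; beyond that the unscaled BigInteger's two's-complement bytes go into a byte-array column. A sketch of the round trip under that assumption:

```scala
import java.math.{BigDecimal => JBigDecimal, BigInteger}

// Illustrative encode/decode; 18 stands in for Decimal.MAX_LONG_DIGITS.
def encodeDecimal(d: JBigDecimal, precision: Int): Either[Long, Array[Byte]] =
  if (precision <= 18) Left(d.unscaledValue().longValueExact())
  else Right(d.unscaledValue().toByteArray)

def decodeDecimal(enc: Either[Long, Array[Byte]], scale: Int): JBigDecimal = enc match {
  case Left(unscaled) => JBigDecimal.valueOf(unscaled, scale)
  case Right(bytes)   => new JBigDecimal(new BigInteger(bytes), scale)
}
```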
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index d558dae50c227..dbad5e070f1fe 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
+import java.math.BigInteger;
import java.util.Arrays;
import java.util.Iterator;
@@ -25,6 +27,7 @@
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -86,13 +89,23 @@ public void close() {
* performance is lost with this translation.
*/
public static final class Row extends InternalRow {
- private int rowId;
+ protected int rowId;
private final ColumnarBatch parent;
private final int fixedLenRowSize;
+ private final ColumnVector[] columns;
+ // Ctor used if this is a top level row.
private Row(ColumnarBatch parent) {
this.parent = parent;
this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols());
+ this.columns = parent.columns;
+ }
+
+ // Ctor used if this is a struct.
+ protected Row(ColumnVector[] columns) {
+ this.parent = null;
+ this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length);
+ this.columns = columns;
}
/**
@@ -103,23 +116,23 @@ public final void markFiltered() {
parent.markFiltered(rowId);
}
+ public ColumnVector[] columns() { return columns; }
+
@Override
- public final int numFields() {
- return parent.numCols();
- }
+ public final int numFields() { return columns.length; }
@Override
/**
* Revisit this. This is expensive.
*/
public final InternalRow copy() {
- UnsafeRow row = new UnsafeRow(parent.numCols());
+ UnsafeRow row = new UnsafeRow(numFields());
row.pointTo(new byte[fixedLenRowSize], fixedLenRowSize);
- for (int i = 0; i < parent.numCols(); i++) {
+ for (int i = 0; i < numFields(); i++) {
if (isNullAt(i)) {
row.setNullAt(i);
} else {
- DataType dt = parent.schema.fields()[i].dataType();
+ DataType dt = columns[i].dataType();
if (dt instanceof IntegerType) {
row.setInt(i, getInt(i));
} else if (dt instanceof LongType) {
@@ -140,70 +153,71 @@ public final boolean anyNull() {
}
@Override
- public final boolean isNullAt(int ordinal) {
- return parent.column(ordinal).getIsNull(rowId);
- }
+ public final boolean isNullAt(int ordinal) { return columns[ordinal].getIsNull(rowId); }
@Override
- public final boolean getBoolean(int ordinal) {
- throw new NotImplementedException();
- }
+ public final boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
@Override
- public final byte getByte(int ordinal) { return parent.column(ordinal).getByte(rowId); }
+ public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
@Override
- public final short getShort(int ordinal) {
- throw new NotImplementedException();
- }
+ public final short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
@Override
- public final int getInt(int ordinal) {
- return parent.column(ordinal).getInt(rowId);
- }
+ public final int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
@Override
- public final long getLong(int ordinal) { return parent.column(ordinal).getLong(rowId); }
+ public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
@Override
- public final float getFloat(int ordinal) {
- throw new NotImplementedException();
- }
+ public final float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
@Override
- public final double getDouble(int ordinal) {
- return parent.column(ordinal).getDouble(rowId);
- }
+ public final double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
@Override
public final Decimal getDecimal(int ordinal, int precision, int scale) {
- throw new NotImplementedException();
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.apply(getLong(ordinal), precision, scale);
+ } else {
+ // TODO: best perf?
+ byte[] bytes = getBinary(ordinal);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(javaDecimal, precision, scale);
+ }
}
@Override
public final UTF8String getUTF8String(int ordinal) {
- ColumnVector.Array a = parent.column(ordinal).getByteArray(rowId);
+ ColumnVector.Array a = columns[ordinal].getByteArray(rowId);
return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
}
@Override
public final byte[] getBinary(int ordinal) {
- throw new NotImplementedException();
+ ColumnVector.Array array = columns[ordinal].getByteArray(rowId);
+ byte[] bytes = new byte[array.length];
+ System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+ return bytes;
}
@Override
public final CalendarInterval getInterval(int ordinal) {
- throw new NotImplementedException();
+ final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
+ final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+ return new CalendarInterval(months, microseconds);
}
@Override
public final InternalRow getStruct(int ordinal, int numFields) {
- return parent.column(ordinal).getStruct(rowId);
+ return columns[ordinal].getStruct(rowId);
}
@Override
public final ArrayData getArray(int ordinal) {
- return parent.column(ordinal).getArray(rowId);
+ return columns[ordinal].getArray(rowId);
}
@Override
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 335124fd5a603..7a224d19d15b7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -19,11 +19,15 @@
import java.nio.ByteOrder;
import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.ShortType;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;
@@ -121,6 +125,26 @@ public final boolean getIsNull(int rowId) {
return Platform.getByte(null, nulls + rowId) == 1;
}
+ //
+ // APIs dealing with Booleans
+ //
+
+ @Override
+ public final void putBoolean(int rowId, boolean value) {
+ Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0));
+ }
+
+ @Override
+ public final void putBooleans(int rowId, int count, boolean value) {
+ byte v = (byte)((value) ? 1 : 0);
+ for (int i = 0; i < count; ++i) {
+ Platform.putByte(null, data + rowId + i, v);
+ }
+ }
+
+ @Override
+ public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; }
+
//
// APIs dealing with Bytes
//
@@ -148,6 +172,34 @@ public final byte getByte(int rowId) {
return Platform.getByte(null, data + rowId);
}
+ //
+ // APIs dealing with shorts
+ //
+
+ @Override
+ public final void putShort(int rowId, short value) {
+ Platform.putShort(null, data + 2 * rowId, value);
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short value) {
+ long offset = data + 2 * rowId;
+ for (int i = 0; i < count; ++i, offset += 2) {
+ Platform.putShort(null, offset, value);
+ }
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2,
+ null, data + 2 * rowId, count * 2);
+ }
+
+ @Override
+ public final short getShort(int rowId) {
+ return Platform.getShort(null, data + 2 * rowId);
+ }
+
//
// APIs dealing with ints
//
@@ -216,6 +268,41 @@ public final long getLong(int rowId) {
return Platform.getLong(null, data + 8 * rowId);
}
+ //
+ // APIs dealing with floats
+ //
+
+ @Override
+ public final void putFloat(int rowId, float value) {
+ Platform.putFloat(null, data + rowId * 4, value);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, float value) {
+ long offset = data + 4 * rowId;
+ for (int i = 0; i < count; ++i, offset += 4) {
+ Platform.putFloat(null, offset, value);
+ }
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4,
+ null, data + 4 * rowId, count * 4);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
+ null, data + rowId * 4, count * 4);
+ }
+
+ @Override
+ public final float getFloat(int rowId) {
+ return Platform.getFloat(null, data + rowId * 4);
+ }
+
+
//
// APIs dealing with doubles
//
@@ -241,7 +328,7 @@ public final void putDoubles(int rowId, int count, double[] src, int srcIndex) {
@Override
public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
- Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex,
+ Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
null, data + rowId * 8, count * 8);
}
@@ -280,7 +367,7 @@ public final int putByteArray(int rowId, byte[] value, int offset, int length) {
}
@Override
- public final void loadBytes(Array array) {
+ public final void loadBytes(ColumnVector.Array array) {
if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length];
Platform.copyMemory(
null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length);
@@ -300,11 +387,14 @@ private final void reserveInternal(int newCapacity) {
Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4);
this.offsetData =
Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4);
- } else if (type instanceof ByteType) {
+ } else if (type instanceof ByteType || type instanceof BooleanType) {
this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity);
- } else if (type instanceof IntegerType) {
+ } else if (type instanceof ShortType) {
+ this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2);
+ } else if (type instanceof IntegerType || type instanceof FloatType) {
this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4);
- } else if (type instanceof LongType || type instanceof DoubleType) {
+ } else if (type instanceof LongType || type instanceof DoubleType ||
+ DecimalType.is64BitDecimalType(type)) {
this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8);
} else if (resultStruct != null) {
// Nothing to store.
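
reserveInternal now sizes the off-heap buffer by element width: 1 byte for booleans/bytes, 2 for shorts, 4 for ints/floats, 8 for longs/doubles and 64-bit decimals. A small sketch of that mapping (illustrative; the real code calls DecimalType.is64BitDecimalType rather than checking precision directly):

```scala
import org.apache.spark.sql.types._

// Bytes per element for the fixed-width types handled above.
def elementSize(t: DataType): Int = t match {
  case BooleanType | ByteType              => 1
  case ShortType                           => 2
  case IntegerType | FloatType             => 4
  case LongType | DoubleType               => 8
  case d: DecimalType if d.precision <= 18 => 8   // 64-bit decimals
  case other => throw new UnsupportedOperationException(s"Unsupported type: $other")
}
```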
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 8197fa11cd4c8..c42bbd642ecae 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -35,8 +35,10 @@ public final class OnHeapColumnVector extends ColumnVector {
// Array for each type. Only 1 is populated for any type.
private byte[] byteData;
+ private short[] shortData;
private int[] intData;
private long[] longData;
+ private float[] floatData;
private double[] doubleData;
// Only set if type is Array.
@@ -104,6 +106,30 @@ public final boolean getIsNull(int rowId) {
return nulls[rowId] == 1;
}
+ //
+ // APIs dealing with Booleans
+ //
+
+ @Override
+ public final void putBoolean(int rowId, boolean value) {
+ byteData[rowId] = (byte)((value) ? 1 : 0);
+ }
+
+ @Override
+ public final void putBooleans(int rowId, int count, boolean value) {
+ byte v = (byte)((value) ? 1 : 0);
+ for (int i = 0; i < count; ++i) {
+ byteData[i + rowId] = v;
+ }
+ }
+
+ @Override
+ public final boolean getBoolean(int rowId) {
+ return byteData[rowId] == 1;
+ }
+
+ //
+
//
// APIs dealing with Bytes
//
@@ -130,6 +156,33 @@ public final byte getByte(int rowId) {
return byteData[rowId];
}
+ //
+ // APIs dealing with Shorts
+ //
+
+ @Override
+ public final void putShort(int rowId, short value) {
+ shortData[rowId] = value;
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short value) {
+ for (int i = 0; i < count; ++i) {
+ shortData[i + rowId] = value;
+ }
+ }
+
+ @Override
+ public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+ System.arraycopy(src, srcIndex, shortData, rowId, count);
+ }
+
+ @Override
+ public final short getShort(int rowId) {
+ return shortData[rowId];
+ }
+
+
//
// APIs dealing with Ints
//
@@ -202,6 +255,31 @@ public final long getLong(int rowId) {
return longData[rowId];
}
+ //
+ // APIs dealing with floats
+ //
+
+ @Override
+ public final void putFloat(int rowId, float value) { floatData[rowId] = value; }
+
+ @Override
+ public final void putFloats(int rowId, int count, float value) {
+ Arrays.fill(floatData, rowId, rowId + count, value);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+ System.arraycopy(src, srcIndex, floatData, rowId, count);
+ }
+
+ @Override
+ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+ Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
+      floatData, Platform.FLOAT_ARRAY_OFFSET + rowId * 4, count * 4);
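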
+ }
+
+ @Override
+ public final float getFloat(int rowId) { return floatData[rowId]; }
//
// APIs dealing with doubles
@@ -253,7 +331,7 @@ public final void putArray(int rowId, int offset, int length) {
}
@Override
- public final void loadBytes(Array array) {
+ public final void loadBytes(ColumnVector.Array array) {
array.byteArray = byteData;
array.byteArrayOffset = array.offset;
}
@@ -277,7 +355,7 @@ public final void reserve(int requiredCapacity) {
// Spilt this function out since it is the slow path.
private final void reserveInternal(int newCapacity) {
- if (this.resultArray != null) {
+ if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) {
int[] newLengths = new int[newCapacity];
int[] newOffsets = new int[newCapacity];
if (this.arrayLengths != null) {
@@ -286,18 +364,30 @@ private final void reserveInternal(int newCapacity) {
}
arrayLengths = newLengths;
arrayOffsets = newOffsets;
+ } else if (type instanceof BooleanType) {
+ byte[] newData = new byte[newCapacity];
+ if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
+ byteData = newData;
} else if (type instanceof ByteType) {
byte[] newData = new byte[newCapacity];
if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
byteData = newData;
+ } else if (type instanceof ShortType) {
+ short[] newData = new short[newCapacity];
+ if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
+ shortData = newData;
} else if (type instanceof IntegerType) {
int[] newData = new int[newCapacity];
if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
intData = newData;
- } else if (type instanceof LongType) {
+ } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) {
long[] newData = new long[newCapacity];
if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended);
longData = newData;
+ } else if (type instanceof FloatType) {
+ float[] newData = new float[newCapacity];
+ if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended);
+ floatData = newData;
} else if (type instanceof DoubleType) {
double[] newData = new double[newCapacity];
if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended);
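
Editor's note: putFloats(rowId, count, src: byte[], srcIndex) above reinterprets raw bytes as floats with a single memory copy. A standalone sketch of the same reinterpretation using java.nio rather than Spark's Platform utility (it assumes native byte order, which is what the raw copy implies):

    import java.nio.{ByteBuffer, ByteOrder}

    // Reinterpret `count` 4-byte groups of `src`, starting at `srcIndex`, as floats.
    def bytesToFloats(src: Array[Byte], srcIndex: Int, count: Int): Array[Float] = {
      val out = new Array[Float](count)
      ByteBuffer.wrap(src, srcIndex, count * 4)
        .order(ByteOrder.nativeOrder())
        .asFloatBuffer()
        .get(out)
      out
    }
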
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
new file mode 100644
index 0000000000000..1c2c0290fc4cd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+/**
+ * A handle to a query that is executing continuously in the background as new data arrives.
+ */
+trait ContinuousQuery {
+
+ /**
+ * Stops the execution of this query if it is running. This method blocks until the threads
+   * performing execution have stopped.
+ */
+ def stop(): Unit
+}
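
Editor's note: ContinuousQuery is only the handle; the entry points that produce one are the stream() methods this patch adds to DataFrameReader and DataFrameWriter further down. A minimal end-to-end sketch, assuming an existing `sqlContext` and hypothetical source/sink provider names (only `stream()` and `stop()` come from this patch):

    // Hypothetical provider names and paths, for illustration only.
    val input = sqlContext.read
      .format("my.streaming.source")
      .stream("/data/in")

    val query: org.apache.spark.sql.ContinuousQuery = input.write
      .format("my.streaming.sink")
      .stream("/data/out")

    // Blocks until the threads performing execution have stopped.
    query.stop()
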
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 518f9dcf94a70..7aa08fb63053b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -474,6 +474,7 @@ class DataFrame private[sql](
val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true)
Alias(Coalesce(Seq(leftCol, rightCol)), col)()
}
+ case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.")
}
// The nullability of output of joined could be different than original column,
// so we can only compare them by exprId
@@ -1383,6 +1384,10 @@ class DataFrame private[sql](
/**
* Returns the first `n` rows.
+ *
+ * @note this method should only be used if the resulting array is expected to be small, as
+ * all the data is loaded into the driver's memory.
+ *
* @group action
* @since 1.3.0
*/
@@ -1682,7 +1687,7 @@ class DataFrame private[sql](
/**
* :: Experimental ::
- * Interface for saving the content of the [[DataFrame]] out into external storage.
+ * Interface for saving the content of the [[DataFrame]] out into external storage or streams.
*
* @group output
* @since 1.4.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 634c1bd4739b1..962fdadf1431d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,17 +29,17 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{CatalystQl}
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.JSONRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
+import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.types.StructType
/**
* :: Experimental ::
* Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems,
- * key-value stores, etc). Use [[SQLContext.read]] to access this.
+ * key-value stores, etc) or data streams. Use [[SQLContext.read]] to access this.
*
* @since 1.4.0
*/
@@ -78,6 +78,27 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
this
}
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * @since 2.0.0
+ */
+ def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString)
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * @since 2.0.0
+ */
+ def option(key: String, value: Long): DataFrameReader = option(key, value.toString)
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * @since 2.0.0
+ */
+ def option(key: String, value: Double): DataFrameReader = option(key, value.toString)
+
/**
* (Scala-specific) Adds input options for the underlying data source.
*
@@ -136,6 +157,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load()
}
+ /**
+ * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path
+ * (e.g. external key-value stores).
+ *
+ * @since 2.0.0
+ */
+ def stream(): DataFrame = {
+ val resolved = ResolvedDataSource.createSource(
+ sqlContext,
+ userSpecifiedSchema = userSpecifiedSchema,
+ providerName = source,
+ options = extraOptions.toMap)
+ DataFrame(sqlContext, StreamingRelation(resolved))
+ }
+
+ /**
+ * Loads input in as a [[DataFrame]], for data streams that read from some path.
+ *
+ * @since 2.0.0
+ */
+ def stream(path: String): DataFrame = {
+ option("path", path).stream()
+ }
+
/**
* Construct a [[DataFrame]] representing the database table accessible via JDBC URL
* url named table and connection properties.
@@ -165,7 +210,6 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
* should be included.
- *
* @since 1.4.0
*/
def jdbc(
@@ -252,6 +296,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*
* You can set the following JSON-specific options to deal with non-standard JSON files:
* `primitivesAsString` (default `false`): infers all primitive values as a string type
+ * `floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal
+ * type
* `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
* `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
* `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
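
Editor's note: the new option(key, value) overloads for Boolean, Long and Double simply call value.toString, so typed literals can be passed without manual conversion; the JSON options listed above are set the same way. A small usage sketch, assuming an existing `sqlContext`:

    val people = sqlContext.read
      .option("primitivesAsString", true)   // Boolean overload added in this patch
      .option("floatAsBigDecimal", false)
      .json("/path/to/people.json")
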
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index b0b6995a2214f..bb3cc02800d51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
-import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
/**
@@ -109,7 +109,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* Null elements will be replaced by "null", and back ticks will be dropped from elements if they
* exist.
*
- *
* @param col1 The name of the first column. Distinct items will make the first item of
* each row.
* @param col2 The name of the second column. Distinct items will make the column names
@@ -374,21 +373,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
val singleCol = df.select(col)
val colType = singleCol.schema.head.dataType
- require(
- colType == StringType || colType.isInstanceOf[IntegralType],
- s"Count-min Sketch only supports string type and integral types, " +
- s"and does not support type $colType."
- )
+ val updater: (CountMinSketch, InternalRow) => Unit = colType match {
+ // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
+ // instead of `addString` to avoid unnecessary conversion.
+ case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes)
+ case ByteType => (sketch, row) => sketch.addLong(row.getByte(0))
+ case ShortType => (sketch, row) => sketch.addLong(row.getShort(0))
+ case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0))
+ case LongType => (sketch, row) => sketch.addLong(row.getLong(0))
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Count-min Sketch only supports string type and integral types, " +
+ s"and does not support type $colType."
+ )
+ }
- singleCol.rdd.aggregate(zero)(
- (sketch: CountMinSketch, row: Row) => {
- sketch.add(row.get(0))
+ singleCol.queryExecution.toRdd.aggregate(zero)(
+ (sketch: CountMinSketch, row: InternalRow) => {
+ updater(sketch, row)
sketch
},
-
- (sketch1: CountMinSketch, sketch2: CountMinSketch) => {
- sketch1.mergeInPlace(sketch2)
- }
+ (sketch1, sketch2) => sketch1.mergeInPlace(sketch2)
)
}
@@ -447,19 +452,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
require(colType == StringType || colType.isInstanceOf[IntegralType],
s"Bloom filter only supports string type and integral types, but got $colType.")
- val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
- (filter, row) =>
- // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
- // instead of `putString` to avoid unnecessary conversion.
- filter.putBinary(row.getUTF8String(0).getBytes)
- filter
- } else {
- (filter, row) =>
- // TODO: specialize it.
- filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
- filter
+ val updater: (BloomFilter, InternalRow) => Unit = colType match {
+ // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
+ // instead of `putString` to avoid unnecessary conversion.
+ case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
+ case ByteType => (filter, row) => filter.putLong(row.getByte(0))
+ case ShortType => (filter, row) => filter.putLong(row.getShort(0))
+ case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
+ case LongType => (filter, row) => filter.putLong(row.getLong(0))
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Bloom filter only supports string type and integral types, " +
+ s"and does not support type $colType."
+ )
}
- singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
+ singleCol.queryExecution.toRdd.aggregate(zero)(
+ (filter: BloomFilter, row: InternalRow) => {
+ updater(filter, row)
+ filter
+ },
+ (filter1, filter2) => filter1.mergeInPlace(filter2)
+ )
}
}
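
Editor's note: from the caller's point of view the aggregation is unchanged; only the per-row update is now specialized by column type before the job runs. A usage sketch, assuming a DataFrame `df` with an integral `id` column (the parameters are illustrative):

    // depth, width, seed
    val cms = df.stat.countMinSketch("id", 10, 2000, 42)
    println(cms.estimateCount(42L))

    // expected distinct items, false-positive probability
    val bf = df.stat.bloomFilter("id", 1000000L, 0.03)
    println(bf.mightContain(42L))
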
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 12eb2393634a9..8060198968988 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -22,17 +22,18 @@ import java.util.Properties
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier}
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.sources.HadoopFsRelation
/**
* :: Experimental ::
* Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems,
- * key-value stores, etc). Use [[DataFrame.write]] to access this.
+ * key-value stores, etc) or data streams. Use [[DataFrame.write]] to access this.
*
* @since 1.4.0
*/
@@ -94,6 +95,27 @@ final class DataFrameWriter private[sql](df: DataFrame) {
this
}
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * @since 2.0.0
+ */
+ def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString)
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * @since 2.0.0
+ */
+ def option(key: String, value: Long): DataFrameWriter = option(key, value.toString)
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * @since 2.0.0
+ */
+ def option(key: String, value: Double): DataFrameWriter = option(key, value.toString)
+
/**
* (Scala-specific) Adds output options for the underlying data source.
*
@@ -183,6 +205,34 @@ final class DataFrameWriter private[sql](df: DataFrame) {
df)
}
+ /**
+ * Starts the execution of the streaming query, which will continually output results to the
+ * given path as new data arrives. The returned [[ContinuousQuery]] object can be used to
+ * interact with the stream.
+ *
+ * @since 2.0.0
+ */
+ def stream(path: String): ContinuousQuery = {
+ option("path", path).stream()
+ }
+
+ /**
+ * Starts the execution of the streaming query, which will continually output results to the
+ * sink configured for this writer as new data arrives. The returned [[ContinuousQuery]] object
+ * can be used to interact with the stream.
+ *
+ * @since 2.0.0
+ */
+ def stream(): ContinuousQuery = {
+ val sink = ResolvedDataSource.createSink(
+ df.sqlContext,
+ source,
+ extraOptions.toMap,
+ normalizedParCols.getOrElse(Nil))
+
+ new StreamExecution(df.sqlContext, df.logicalPlan, sink)
+ }
+
/**
* Inserts the content of the [[DataFrame]] to the specified table. It requires that
* the schema of the [[DataFrame]] is the same as the schema of the table.
@@ -255,7 +305,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
/**
* The given column name may not be equal to any of the existing column names if we were in
- * case-insensitive context. Normalize the given column name to the real one so that we don't
+ * case-insensitive context. Normalize the given column name to the real one so that we don't
* need to care about case sensitivity afterwards.
*/
private def normalize(columnName: String, columnType: String): String = {
@@ -339,7 +389,6 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
* should be included.
- *
* @since 1.4.0
*/
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f182270a08729..378763268acc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -74,6 +74,7 @@ class Dataset[T] private[sql](
* same object type (that will be possibly resolved to a different schema).
*/
private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
+ unresolvedTEncoder.validate(logicalPlan.output)
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
@@ -85,7 +86,7 @@ class Dataset[T] private[sql](
*/
private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
- private implicit def classTag = resolvedTEncoder.clsTag
+ private implicit def classTag = unresolvedTEncoder.clsTag
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b3f8284364782..c0e28f2dc5bd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -116,6 +116,7 @@ class GroupedDataset[K, V] private[sql](
MapGroups(
f,
groupingAttributes,
+ dataAttributes,
logicalPlan))
}
@@ -310,6 +311,8 @@ class GroupedDataset[K, V] private[sql](
f,
this.groupingAttributes,
other.groupingAttributes,
+ this.dataAttributes,
+ other.dataAttributes,
this.logicalPlan,
other.logicalPlan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index c9ba6700998c3..eb9da0bd4fd4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.parquet.hadoop.ParquetOutputCommitter
+import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.parser.ParserConf
import org.apache.spark.util.Utils
@@ -519,7 +520,7 @@ private[spark] object SQLConf {
*
* SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads).
*/
-private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf {
+private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf with Logging {
import SQLConf._
/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
@@ -628,7 +629,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon
// Only verify configs in the SQLConf object
entry.valueConverter(value)
}
- settings.put(key, value)
+ setConfWithCheck(key, value)
}
/** Set the given Spark SQL configuration property. */
@@ -636,7 +637,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon
require(entry != null, "entry cannot be null")
require(value != null, s"value cannot be null for key: ${entry.key}")
require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
- settings.put(entry.key, entry.stringConverter(value))
+ setConfWithCheck(entry.key, entry.stringConverter(value))
}
/** Return the value of Spark SQL configuration property for the given key. */
@@ -699,6 +700,13 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon
}.toSeq
}
+ private def setConfWithCheck(key: String, value: String): Unit = {
+ if (key.startsWith("spark.") && !key.startsWith("spark.sql.")) {
+ logWarning(s"Attempt to set non-Spark SQL config in SQLConf: key = $key, value = $value")
+ }
+ settings.put(key, value)
+ }
+
private[spark] def unsetConf(key: String): Unit = {
settings.remove(key)
}
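
Editor's note: setConfWithCheck only logs; the value is still stored, so existing behaviour is preserved. A sketch of which keys trigger the warning, assuming an existing `sqlContext`:

    sqlContext.setConf("spark.sql.shuffle.partitions", "400") // spark.sql.* key: no warning
    sqlContext.setConf("my.app.flag", "true")                 // not a spark.* key: no warning
    sqlContext.setConf("spark.executor.memory", "4g")         // spark.* but not spark.sql.*: warns
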
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index be28df3a51557..1661fdbec5326 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -206,10 +206,7 @@ class SQLContext private[sql](
@transient
protected[sql] val sqlParser: ParserInterface = new SparkQl(conf)
- @transient
- protected[sql] val ddlParser: DDLParser = new DDLParser(sqlParser)
-
- protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false)
+ protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql)
protected[sql] def executeSql(sql: String):
org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql))
@@ -582,10 +579,9 @@ class SQLContext private[sql](
DataFrame(self, LocalRelation(attrSeq, rows.toSeq))
}
-
/**
* :: Experimental ::
- * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]].
+ * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]].
* {{{
* sqlContext.read.parquet("/path/to/file.parquet")
* sqlContext.read.schema(schema).json("/path/to/file.json")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index ab414799f1a42..16c4095db722a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -39,6 +39,8 @@ abstract class SQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
+ // Primitives
+
/** @since 1.6.0 */
implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder()
@@ -56,13 +58,72 @@ abstract class SQLImplicits {
/** @since 1.6.0 */
implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder()
- /** @since 1.6.0 */
+ /** @since 1.6.0 */
implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder()
/** @since 1.6.0 */
implicit def newStringEncoder: Encoder[String] = ExpressionEncoder()
+ // Seqs
+
+ /** @since 1.6.1 */
+ implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
+
+ // Arrays
+
+ /** @since 1.6.1 */
+ implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] =
+ ExpressionEncoder()
+
/**
* Creates a [[Dataset]] from an RDD.
* @since 1.6.0
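
Editor's note: with these encoders in implicit scope, Datasets whose element type is a Seq or Array of primitives (or of Products) can be created without hand-written encoders. A small sketch, assuming an existing `sqlContext`:

    import sqlContext.implicits._

    // Uses newIntSeqEncoder and newStringArrayEncoder from above.
    val ds1 = sqlContext.createDataset(Seq(Seq(1, 2, 3), Seq(4, 5)))
    val ds2 = sqlContext.createDataset(Seq(Array("a", "b"), Array("c")))
    ds1.collect()
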
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index b1bbb1da10a39..ea20115770f79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -17,8 +17,12 @@
package org.apache.spark.sql.execution;
+import java.io.IOException;
+import java.util.LinkedList;
+
import scala.collection.Iterator;
+import org.apache.spark.TaskContext;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -29,36 +33,50 @@
* TODO: replaced it by batched columnar format.
*/
public class BufferedRowIterator {
- protected InternalRow currentRow;
+ protected LinkedList currentRows = new LinkedList<>();
protected Iterator input;
// used when there is no column in output
protected UnsafeRow unsafeRow = new UnsafeRow(0);
- public boolean hasNext() {
- if (currentRow == null) {
+ public boolean hasNext() throws IOException {
+ if (currentRows.isEmpty()) {
processNext();
}
- return currentRow != null;
+ return !currentRows.isEmpty();
}
public InternalRow next() {
- InternalRow r = currentRow;
- currentRow = null;
- return r;
+ return currentRows.remove();
}
public void setInput(Iterator iter) {
input = iter;
}
+ /**
+   * Returns whether `processNext()` should stop processing the next row from `input` or not.
+ *
+ * If it returns true, the caller should exit the loop (return from processNext()).
+ */
+ protected boolean shouldStop() {
+ return !currentRows.isEmpty();
+ }
+
+ /**
+ * Increase the peak execution memory for current task.
+ */
+ protected void incPeakExecutionMemory(long size) {
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(size);
+ }
+
/**
* Processes the input until have a row as output (currentRow).
*
* After it's called, if currentRow is still null, it means no more rows left.
*/
- protected void processNext() {
+ protected void processNext() throws IOException {
if (input.hasNext()) {
- currentRow = input.next();
+ currentRows.add(input.next());
}
}
}
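
Editor's note: the generated iterator now buffers whole batches of produced rows and polls shouldStop() so a produce loop can hand control back as soon as output is available. A simplified Scala model of that contract (not the generated code itself):

    import scala.collection.mutable

    abstract class BufferedRows[T] {
      protected val currentRows = mutable.Queue.empty[T]

      def hasNext: Boolean = {
        if (currentRows.isEmpty) processNext()
        currentRows.nonEmpty
      }
      def next(): T = currentRows.dequeue()

      // Generated loops check this and return once at least one row is buffered.
      protected def shouldStop(): Boolean = currentRows.nonEmpty

      // Appends zero or more rows to currentRows; leaving it empty means no rows are left.
      protected def processNext(): Unit
    }
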
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 3770883af1e2f..97f65f18bfdcc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -57,6 +57,69 @@ case class Exchange(
override def output: Seq[Attribute] = child.output
+ private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+
+ override protected def doPrepare(): Unit = {
+ // If an ExchangeCoordinator is needed, we register this Exchange operator
+ // to the coordinator when we do prepare. It is important to make sure
+ // we register this operator right before the execution instead of register it
+ // in the constructor because it is possible that we create new instances of
+ // Exchange operators when we transform the physical plan
+ // (then the ExchangeCoordinator will hold references of unneeded Exchanges).
+ // So, we should only call registerExchange just before we start to execute
+ // the plan.
+ coordinator match {
+ case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this)
+ case None =>
+ }
+ }
+
+ /**
+ * Returns a [[ShuffleDependency]] that will partition rows of its child based on
+ * the partitioning scheme defined in `newPartitioning`. Those partitions of
+ * the returned ShuffleDependency will be the input of shuffle.
+ */
+ private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
+ Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer)
+ }
+
+ /**
+ * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset.
+ * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional
+ * partition start indices array. If this optional array is defined, the returned
+ * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
+ */
+ private[sql] def preparePostShuffleRDD(
+ shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
+ specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
+ // If an array of partition start indices is provided, we need to use this array
+ // to create the ShuffledRowRDD. Also, we need to update newPartitioning to
+ // update the number of post-shuffle partitions.
+ specifiedPartitionStartIndices.foreach { indices =>
+ assert(newPartitioning.isInstanceOf[HashPartitioning])
+ newPartitioning = UnknownPartitioning(indices.length)
+ }
+ new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ coordinator match {
+ case Some(exchangeCoordinator) =>
+ val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
+ assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
+ shuffleRDD
+ case None =>
+ val shuffleDependency = prepareShuffleDependency()
+ preparePostShuffleRDD(shuffleDependency)
+ }
+ }
+}
+
+object Exchange {
+ def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
+ Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
+ }
+
/**
* Determines whether records must be defensively copied before being sent to the shuffle.
* Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -82,7 +145,7 @@ case class Exchange(
// passed instead of directly passing the number of partitions in order to guard against
// corner-cases where a partitioner constructed with `numPartitions` partitions may output
// fewer partitions (like RangePartitioner, for example).
- val conf = child.sqlContext.sparkContext.conf
+ val conf = SparkEnv.get.conf
val shuffleManager = SparkEnv.get.shuffleManager
val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager]
val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
@@ -117,30 +180,16 @@ case class Exchange(
}
}
- private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
-
- override protected def doPrepare(): Unit = {
- // If an ExchangeCoordinator is needed, we register this Exchange operator
- // to the coordinator when we do prepare. It is important to make sure
- // we register this operator right before the execution instead of register it
- // in the constructor because it is possible that we create new instances of
- // Exchange operators when we transform the physical plan
- // (then the ExchangeCoordinator will hold references of unneeded Exchanges).
- // So, we should only call registerExchange just before we start to execute
- // the plan.
- coordinator match {
- case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this)
- case None =>
- }
- }
-
/**
* Returns a [[ShuffleDependency]] that will partition rows of its child based on
* the partitioning scheme defined in `newPartitioning`. Those partitions of
* the returned ShuffleDependency will be the input of shuffle.
*/
- private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
- val rdd = child.execute()
+ private[sql] def prepareShuffleDependency(
+ rdd: RDD[InternalRow],
+ outputAttributes: Seq[Attribute],
+ newPartitioning: Partitioning,
+ serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
@@ -160,7 +209,7 @@ case class Exchange(
// We need to use an interpreted ordering here because generated orderings cannot be
// serialized and this ordering needs to be created on the driver in order to be passed into
// Spark core code.
- implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output)
+ implicit val ordering = new InterpretedOrdering(sortingExpressions, outputAttributes)
new RangePartitioner(numPartitions, rddForSampling, ascending = true)
case SinglePartition =>
new Partitioner {
@@ -180,7 +229,7 @@ case class Exchange(
position
}
case h: HashPartitioning =>
- val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output)
+ val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
case RangePartitioning(_, _) | SinglePartition => identity
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
@@ -211,43 +260,6 @@ case class Exchange(
dependency
}
-
- /**
- * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset.
- * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional
- * partition start indices array. If this optional array is defined, the returned
- * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
- */
- private[sql] def preparePostShuffleRDD(
- shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
- specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
- // If an array of partition start indices is provided, we need to use this array
- // to create the ShuffledRowRDD. Also, we need to update newPartitioning to
- // update the number of post-shuffle partitions.
- specifiedPartitionStartIndices.foreach { indices =>
- assert(newPartitioning.isInstanceOf[HashPartitioning])
- newPartitioning = UnknownPartitioning(indices.length)
- }
- new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
- }
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- coordinator match {
- case Some(exchangeCoordinator) =>
- val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
- assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
- shuffleRDD
- case None =>
- val shuffleDependency = prepareShuffleDependency()
- preparePostShuffleRDD(shuffleDependency)
- }
- }
-}
-
-object Exchange {
- def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
- Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
- }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 107570f9dbcc8..8616fe317034f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
/**
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -44,7 +44,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
- sqlContext.planner.plan(optimizedPlan).next()
+ sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b19b772409d83..3cc99d3c7b1b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -200,47 +200,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
inputSchema: Seq[Attribute],
useSubexprElimination: Boolean = false): () => MutableProjection = {
log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema")
- try {
- GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination)
- } catch {
- case e: Exception =>
- if (isTesting) {
- throw e
- } else {
- log.error("Failed to generate mutable projection, fallback to interpreted", e)
- () => new InterpretedMutableProjection(expressions, inputSchema)
- }
- }
+ GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination)
}
protected def newPredicate(
expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
- try {
- GeneratePredicate.generate(expression, inputSchema)
- } catch {
- case e: Exception =>
- if (isTesting) {
- throw e
- } else {
- log.error("Failed to generate predicate, fallback to interpreted", e)
- InterpretedPredicate.create(expression, inputSchema)
- }
- }
+ GeneratePredicate.generate(expression, inputSchema)
}
protected def newOrdering(
order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = {
- try {
- GenerateOrdering.generate(order, inputSchema)
- } catch {
- case e: Exception =>
- if (isTesting) {
- throw e
- } else {
- log.error("Failed to generate ordering, fallback to interpreted", e)
- new InterpretedOrdering(order, inputSchema)
- }
- }
+ GenerateOrdering.generate(order, inputSchema)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
index f6055306b6c97..4174e27e9c8b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
@@ -16,11 +16,14 @@
*/
package org.apache.spark.sql.execution
+import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.types.StructType
private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) {
/** Check if a command should not be explained. */
@@ -55,6 +58,89 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly
getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs)
ExplainCommand(nodeToPlan(query), extended = extended.isDefined)
+ case Token("TOK_REFRESHTABLE", nameParts :: Nil) =>
+ val tableIdent = extractTableIdent(nameParts)
+ RefreshTable(tableIdent)
+
+ case Token("TOK_CREATETABLEUSING", createTableArgs) =>
+ val Seq(
+ temp,
+ allowExisting,
+ Some(tabName),
+ tableCols,
+ Some(Token("TOK_TABLEPROVIDER", providerNameParts)),
+ tableOpts,
+ tableAs) = getClauses(Seq(
+ "TEMPORARY",
+ "TOK_IFNOTEXISTS",
+ "TOK_TABNAME", "TOK_TABCOLLIST",
+ "TOK_TABLEPROVIDER",
+ "TOK_TABLEOPTIONS",
+ "TOK_QUERY"), createTableArgs)
+
+ val tableIdent: TableIdentifier = extractTableIdent(tabName)
+
+ val columns = tableCols.map {
+ case Token("TOK_TABCOLLIST", fields) => StructType(fields.map(nodeToStructField))
+ }
+
+ val provider = providerNameParts.map {
+ case Token(name, Nil) => name
+ }.mkString(".")
+
+ val options: Map[String, String] = tableOpts.toSeq.flatMap {
+ case Token("TOK_TABLEOPTIONS", options) =>
+ options.map {
+ case Token("TOK_TABLEOPTION", keysAndValue) =>
+ val key = keysAndValue.init.map(_.text).mkString(".")
+ val value = unquoteString(keysAndValue.last.text)
+ (key, value)
+ }
+ }.toMap
+
+ val asClause = tableAs.map(nodeToPlan(_))
+
+ if (temp.isDefined && allowExisting.isDefined) {
+ throw new AnalysisException(
+ "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.")
+ }
+
+ if (asClause.isDefined) {
+ if (columns.isDefined) {
+ throw new AnalysisException(
+ "a CREATE TABLE AS SELECT statement does not allow column definitions.")
+ }
+
+ val mode = if (allowExisting.isDefined) {
+ SaveMode.Ignore
+ } else if (temp.isDefined) {
+ SaveMode.Overwrite
+ } else {
+ SaveMode.ErrorIfExists
+ }
+
+ CreateTableUsingAsSelect(tableIdent,
+ provider,
+ temp.isDefined,
+ Array.empty[String],
+ bucketSpec = None,
+ mode,
+ options,
+ asClause.get)
+ } else {
+ CreateTableUsing(
+ tableIdent,
+ columns,
+ provider,
+ temp.isDefined,
+ options,
+ allowExisting.isDefined,
+ managedIfNoPath = false)
+ }
+
+ case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) =>
+ SetDatabaseCommand(cleanIdentifier(database))
+
case Token("TOK_DESCTABLE", describeArgs) =>
// Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
val Some(tableType) :: formatted :: extended :: pretty :: Nil =
@@ -65,26 +151,30 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly
nodeToDescribeFallback(node)
} else {
tableType match {
- case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) =>
+ case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts) :: Nil) =>
nameParts match {
- case Token(".", dbName :: tableName :: Nil) =>
+ case Token(dbName, Nil) :: Token(tableName, Nil) :: Nil =>
// It is describing a table with the format like "describe db.table".
// TODO: Actually, a user may mean tableName.columnName. Need to resolve this
// issue.
- val tableIdent = extractTableIdent(nameParts)
+ val tableIdent = TableIdentifier(
+ cleanIdentifier(tableName), Some(cleanIdentifier(dbName)))
datasources.DescribeCommand(
UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined)
- case Token(".", dbName :: tableName :: colName :: Nil) =>
+ case Token(dbName, Nil) :: Token(tableName, Nil) :: Token(colName, Nil) :: Nil =>
// It is describing a column with the format like "describe db.table column".
nodeToDescribeFallback(node)
- case tableName =>
+ case tableName :: Nil =>
// It is describing a table with the format like "describe table".
datasources.DescribeCommand(
- UnresolvedRelation(TableIdentifier(tableName.text), None),
+ UnresolvedRelation(TableIdentifier(cleanIdentifier(tableName.text)), None),
isExtended = extended.isDefined)
+ case _ =>
+ nodeToDescribeFallback(node)
}
// All other cases.
- case _ => nodeToDescribeFallback(node)
+ case _ =>
+ nodeToDescribeFallback(node)
}
}
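
Editor's note: the new TOK_CREATETABLEUSING branch covers the data source DDL that previously went through the separate DDLParser (removed from SQLContext above). A sketch of the statements it parses, assuming an existing `sqlContext` (provider and paths are illustrative):

    // Parsed into CreateTableUsing (optional column list, no AS clause).
    sqlContext.sql(
      "CREATE TEMPORARY TABLE events USING parquet OPTIONS (path '/tmp/events')")

    // Parsed into CreateTableUsingAsSelect (AS clause; a column list is not allowed).
    sqlContext.sql(
      "CREATE TEMPORARY TABLE events_copy USING parquet OPTIONS (path '/tmp/copy') " +
        "AS SELECT * FROM events")
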
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 60fbb595e5758..ee392e4e8debe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -298,16 +298,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Distinct(child) =>
throw new IllegalStateException(
"logical distinct operator should have been replaced by aggregate in the optimizer")
+ case logical.Intersect(left, right) =>
+ throw new IllegalStateException(
+ "logical intersect operator should have been replaced by semi-join in the optimizer")
case logical.MapPartitions(f, in, out, child) =>
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
execution.AppendColumns(f, in, out, planLater(child)) :: Nil
- case logical.MapGroups(f, key, in, out, grouping, child) =>
- execution.MapGroups(f, key, in, out, grouping, planLater(child)) :: Nil
- case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, left, right) =>
+ case logical.MapGroups(f, key, in, out, grouping, data, child) =>
+ execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil
+ case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) =>
execution.CoGroup(
- f, keyObj, lObj, rObj, out, lGroup, rGroup, planLater(left), planLater(right)) :: Nil
+ f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr,
+ planLater(left), planLater(right)) :: Nil
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
@@ -334,14 +338,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
+ case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) =>
+ execution.CollectLimit(limit, planLater(child)) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
- execution.Limit(limit, planLater(child)) :: Nil
+ val perPartitionLimit = execution.LocalLimit(limit, planLater(child))
+ val globalLimit = execution.GlobalLimit(limit, perPartitionLimit)
+ globalLimit :: Nil
case logical.Union(unionChildren) =>
execution.Union(unionChildren.map(planLater)) :: Nil
case logical.Except(left, right) =>
execution.Except(planLater(left), planLater(right)) :: Nil
- case logical.Intersect(left, right) =>
- execution.Intersect(planLater(left), planLater(right)) :: Nil
case g @ logical.Generate(generator, join, outer, _, _, child) =>
execution.Generate(
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
@@ -356,6 +362,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
+ case logical.ReturnAnswer(child) => planLater(child) :: Nil
case _ => Nil
}
}
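
Editor's note: together with the ReturnAnswer wrapper added in QueryExecution above, limits are now planned in two ways: a limit that is the final operator before collecting becomes CollectLimit, while a limit feeding further operators becomes a per-partition LocalLimit followed by a GlobalLimit. A sketch, assuming an existing DataFrame `df` with a `key` column:

    // Root-level limit: planned as CollectLimit(5, ...).
    df.limit(5).explain()

    // Limit below other operators: planned as GlobalLimit(5, LocalLimit(5, ...)).
    df.limit(5).groupBy("key").count().explain()
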
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 57f4945de9804..131efea20f31e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -22,15 +22,25 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
+import org.apache.spark.util.Utils
/**
* An interface for those physical operators that support codegen.
*/
trait CodegenSupport extends SparkPlan {
+ /** Prefix used in the current operator's variable names. */
+ private def variablePrefix: String = this match {
+ case _: TungstenAggregate => "agg"
+ case _ => nodeName.toLowerCase
+ }
+
/**
* Whether this SparkPlan support whole stage codegen or not.
*/
@@ -42,10 +52,16 @@ trait CodegenSupport extends SparkPlan {
private var parent: CodegenSupport = null
/**
- * Returns an input RDD of InternalRow and Java source code to process them.
+ * Returns the RDD of InternalRow which generates the input rows.
+ */
+ def upstream(): RDD[InternalRow]
+
+ /**
+ * Returns Java source code to process the rows from upstream.
*/
- def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = {
+ def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
+ ctx.freshNamePrefix = variablePrefix
doProduce(ctx)
}
@@ -66,16 +82,41 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), wich will call parent.doConsume()
* }
*/
- protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
+ protected def doProduce(ctx: CodegenContext): String
/**
- * Consume the columns generated from current SparkPlan, call it's parent or create an iterator.
+   * Consume the columns generated from the current SparkPlan, call its parent.
*/
- protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
- assert(columns.length == output.length)
- parent.doConsume(ctx, this, columns)
+ def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
+ if (input != null) {
+ assert(input.length == output.length)
+ }
+ parent.consumeChild(ctx, this, input, row)
}
+ /**
+   * Consume the columns generated from its child, call doConsume() or emit the rows.
+ */
+ def consumeChild(
+ ctx: CodegenContext,
+ child: SparkPlan,
+ input: Seq[ExprCode],
+ row: String = null): String = {
+ ctx.freshNamePrefix = variablePrefix
+ if (row != null) {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ val evals = child.output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+ }
+ s"""
+ | ${evals.map(_.code).mkString("\n")}
+ | ${doConsume(ctx, evals)}
+ """.stripMargin
+ } else {
+ doConsume(ctx, input)
+ }
+ }
/**
* Generate the Java source code to process the rows from child SparkPlan.
@@ -89,7 +130,9 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
- def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
+ protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ throw new UnsupportedOperationException
+ }
}
@@ -102,31 +145,39 @@ trait CodegenSupport extends SparkPlan {
case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+ override def doPrepare(): Unit = {
+ child.prepare()
+ }
- override def supportCodegen: Boolean = true
+ override def doExecute(): RDD[InternalRow] = {
+ child.execute()
+ }
+
+ override def supportCodegen: Boolean = false
- override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
+ child.execute()
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
- val code = s"""
- | while (input.hasNext()) {
+ s"""
+ | while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
- | ${columns.map(_.code).mkString("\n")}
- | ${consume(ctx, columns)}
+ | ${columns.map(_.code).mkString("\n").trim}
+ | ${consume(ctx, columns).trim}
+ | if (shouldStop()) {
+ | return;
+ | }
| }
""".stripMargin
- (child.execute(), code)
- }
-
- def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
- throw new UnsupportedOperationException
- }
-
- override def doExecute(): RDD[InternalRow] = {
- throw new UnsupportedOperationException
}
override def simpleString: String = "INPUT"
@@ -143,16 +194,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
*
* -> execute()
* |
- * doExecute() --------> produce()
+ * doExecute() ---------> upstream() -------> upstream() ------> execute()
+ * |
+ * -----------------> produce()
* |
* doProduce() -------> produce()
* |
- * doProduce() ---> execute()
+ * doProduce()
* |
* consume()
- * doConsume() ------------|
+ * consumeChild() <-----------|
+ * |
+ * doConsume()
* |
- * doConsume() <----- consume()
+ * consumeChild() <----- consume()
*
* SparkPlan A should override doProduce() and doConsume().
*
@@ -162,37 +217,49 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
extends SparkPlan with CodegenSupport {
+ override def supportCodegen: Boolean = false
+
override def output: Seq[Attribute] = plan.output
+ override def outputPartitioning: Partitioning = plan.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+
+ override def doPrepare(): Unit = {
+ plan.prepare()
+ }
override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
- val (rdd, code) = plan.produce(ctx, this)
+ val code = plan.produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
- return new GeneratedIterator(references);
+ return new GeneratedIterator(references);
}
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
- private Object[] references;
- ${ctx.declareMutableStates()}
+ private Object[] references;
+ ${ctx.declareMutableStates()}
+
+ public GeneratedIterator(Object[] references) {
+ this.references = references;
+ ${ctx.initMutableStates()}
+ }
- public GeneratedIterator(Object[] references) {
- this.references = references;
- ${ctx.initMutableStates()}
- }
+ ${ctx.declareAddedFunctions()}
- protected void processNext() {
- $code
- }
+ protected void processNext() throws java.io.IOException {
+ ${code.trim}
+ }
}
- """
+ """
+
// try to compile, helpful for debug
// println(s"${CodeFormatter.format(source)}")
CodeGenerator.compile(source)
- rdd.mapPartitions { iter =>
+ plan.upstream().mapPartitions { iter =>
+
val clazz = CodeGenerator.compile(source)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.setInput(iter)
@@ -203,29 +270,44 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
}
- override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
throw new UnsupportedOperationException
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
- if (input.nonEmpty) {
- val colExprs = output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable)
- }
- // generate the code to create a UnsafeRow
- ctx.currentVars = input
- val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
- s"""
- | ${code.code.trim}
- | currentRow = ${code.value};
- | return;
- """.stripMargin
- } else {
- // There is no columns
+ override def doProduce(ctx: CodegenContext): String = {
+ throw new UnsupportedOperationException
+ }
+
+ override def consumeChild(
+ ctx: CodegenContext,
+ child: SparkPlan,
+ input: Seq[ExprCode],
+ row: String = null): String = {
+
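+ // The child passes either a complete UnsafeRow (row) or per-column variables (input);
+ // in either case, emit code that buffers an UnsafeRow for the iterator to return.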
+ if (row != null) {
+ // There is an UnsafeRow already
s"""
- | currentRow = unsafeRow;
- | return;
+ | currentRows.add($row.copy());
""".stripMargin
+ } else {
+ assert(input != null)
+ if (input.nonEmpty) {
+ val colExprs = output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable)
+ }
+ // generate the code to create an UnsafeRow
+ ctx.currentVars = input
+ val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+ s"""
+ | ${code.code.trim}
+ | currentRows.add(${code.value}.copy());
+ """.stripMargin
+ } else {
+ // There are no columns
+ s"""
+ | currentRows.add(unsafeRow);
+ """.stripMargin
+ }
}
}
@@ -246,7 +328,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
builder.append(simpleString)
builder.append("\n")
- plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder)
+ plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
@@ -286,13 +368,19 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
case plan: CodegenSupport if supportCodegen(plan) &&
// Whole stage codegen is only useful when there are at least two levels of operators that
// support it (save at least one projection/iterator).
- plan.children.exists(supportCodegen) =>
+ (Utils.isTesting || plan.children.exists(supportCodegen)) =>
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
+ // The build side can't be fused into this stage, so compile it separately
+ case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
+ b.copy(left = apply(left))
+ case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
+ b.copy(right = apply(right))
case p if !supportCodegen(p) =>
- inputs += p
- InputAdapter(p)
+ val input = apply(p) // collapse them recursively
+ inputs += input
+ InputAdapter(input)
}.asInstanceOf[CodegenSupport]
WholeStageCodegen(combined, inputs)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 26a7340f1ae10..84154a47de393 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -198,7 +198,8 @@ case class Window(
functions,
ordinal,
child.output,
- (expressions, schema) => newMutableProjection(expressions, schema))
+ (expressions, schema) =>
+ newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
// Create the factory
val factory = key match {
@@ -210,7 +211,8 @@ case class Window(
ordinal,
functions,
child.output,
- (expressions, schema) => newMutableProjection(expressions, schema),
+ (expressions, schema) =>
+ newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
offset)
// Growing Frame.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 0c74df0aa5fdd..38da82c47ce15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -238,7 +238,7 @@ abstract class AggregationIterator(
resultProjection(joinedRow(currentGroupingKey, currentBuffer))
}
} else {
- // Grouping-only: we only output values of grouping expressions.
+ // Grouping-only: we only output values based on grouping expressions.
val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
(currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
resultProjection(currentGroupingKey)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 23e54f344d252..9d9f14f2dd014 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -17,16 +17,18 @@
package org.apache.spark.sql.execution.aggregate
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.KVIterator
case class TungstenAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -114,20 +116,38 @@ case class TungstenAggregate(
}
}
+ // all the modes of the aggregate expressions
+ private val modes = aggregateExpressions.map(_.mode).distinct
+
override def supportCodegen: Boolean = {
- groupingExpressions.isEmpty &&
- // ImperativeAggregate is not supported right now
- !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
- // final aggregation only have one row, do not need to codegen
- !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+ // ImperativeAggregate is not supported right now
+ !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+ }
+
+ override def upstream(): RDD[InternalRow] = {
+ child.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ if (groupingExpressions.isEmpty) {
+ doProduceWithoutKeys(ctx)
+ } else {
+ doProduceWithKeys(ctx)
+ }
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ if (groupingExpressions.isEmpty) {
+ doConsumeWithoutKeys(ctx, input)
+ } else {
+ doConsumeWithKeys(ctx, input)
+ }
}
// The variables used as aggregation buffer
private var bufVars: Seq[ExprCode] = _
- private val modes = aggregateExpressions.map(_.mode).distinct
-
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ private def doProduceWithoutKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
@@ -137,63 +157,404 @@ case class TungstenAggregate(
bufVars = initExpr.map { e =>
val isNull = ctx.freshName("bufIsNull")
val value = ctx.freshName("bufValue")
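+ // keep the buffer in mutable fields of the generated class so it is shared across the generated functions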
+ ctx.addMutableState("boolean", isNull, "")
+ ctx.addMutableState(ctx.javaType(e.dataType), value, "")
// The initial expression should not access any column
val ev = e.gen(ctx)
val initVars = s"""
- | boolean $isNull = ${ev.isNull};
- | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
""".stripMargin
ExprCode(ev.code + initVars, isNull, value)
}
- val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
- val source =
+ // generate variables for output
+ val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
+ val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
+ // evaluate aggregate results
+ ctx.currentVars = bufVars
+ val aggResults = functions.map(_.evaluateExpression).map { e =>
+ BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+ }
+ // evaluate result expressions
+ ctx.currentVars = aggResults
+ val resultVars = resultExpressions.map { e =>
+ BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
+ }
+ (resultVars, s"""
+ | ${aggResults.map(_.code).mkString("\n")}
+ | ${resultVars.map(_.code).mkString("\n")}
+ """.stripMargin)
+ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+ // output the aggregate buffer directly
+ (bufVars, "")
+ } else {
+ // there are no aggregate functions; the results should be literals
+ val resultVars = resultExpressions.map(_.gen(ctx))
+ (resultVars, resultVars.map(_.code).mkString("\n"))
+ }
+
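+ // consume the whole child input inside a separate generated method that folds every row into the buffer variables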
+ val doAgg = ctx.freshName("doAggregateWithoutKey")
+ ctx.addNewFunction(doAgg,
s"""
- | if (!$initAgg) {
- | $initAgg = true;
- |
+ | private void $doAgg() throws java.io.IOException {
| // initialize aggregation buffer
| ${bufVars.map(_.code).mkString("\n")}
|
- | $childSource
- |
- | // output the result
- | ${consume(ctx, bufVars)}
+ | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
- """.stripMargin
+ """.stripMargin)
- (rdd, source)
+ s"""
+ | if (!$initAgg) {
+ | $initAgg = true;
+ | $doAgg();
+ |
+ | // output the result
+ | ${genResult.trim}
+ |
+ | ${consume(ctx, resultVars).trim}
+ | }
+ """.stripMargin
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- // the mode could be only Partial or PartialMerge
- val updateExpr = if (modes.contains(Partial)) {
- functions.flatMap(_.updateExpressions)
- } else {
- functions.flatMap(_.mergeExpressions)
+ val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+ val updateExpr = aggregateExpressions.flatMap { e =>
+ e.mode match {
+ case Partial | Complete =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+ case PartialMerge | Final =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+ }
}
-
- val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
- val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
ctx.currentVars = bufVars ++ input
// TODO: support subexpression elimination
- val codes = boundExpr.zipWithIndex.map { case (e, i) =>
- val ev = e.gen(ctx)
+ val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx))
+ // the aggregate buffer should be updated atomically
+ val updates = aggVals.zipWithIndex.map { case (ev, i) =>
s"""
- | ${ev.code}
| ${bufVars(i).isNull} = ${ev.isNull};
| ${bufVars(i).value} = ${ev.value};
""".stripMargin
}
-
s"""
- | // do aggregate and update aggregation buffer
- | ${codes.mkString("")}
+ | // do aggregate
+ | ${aggVals.map(_.code).mkString("\n").trim}
+ | // update aggregation buffer
+ | ${updates.mkString("\n").trim}
""".stripMargin
}
+ private val groupingAttributes = groupingExpressions.map(_.toAttribute)
+ private val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+ private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
+ .filter(_.isInstanceOf[DeclarativeAggregate])
+ .map(_.asInstanceOf[DeclarativeAggregate])
+ private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
+ private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+
+ // The names for the hash map and sorter in the generated code
+ private var hashMapTerm: String = _
+ private var sorterTerm: String = _
+
+ /**
+ * This is called by the generated Java class, so it must be public.
+ */
+ def createHashMap(): UnsafeFixedWidthAggregationMap = {
+ // create initialized aggregate buffer
+ val initExpr = declFunctions.flatMap(f => f.initialValues)
+ val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
+
+ // create hashMap
+ new UnsafeFixedWidthAggregationMap(
+ initialBuffer,
+ bufferSchema,
+ groupingKeySchema,
+ TaskContext.get().taskMemoryManager(),
+ 1024 * 16, // initial capacity
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
+ false // disable tracking of performance metrics
+ )
+ }
+
+ /**
+ * This is called by the generated Java class, so it must be public.
+ */
+ def createUnsafeJoiner(): UnsafeRowJoiner = {
+ GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+ }
+
+ /**
+ * Called by the generated Java class to finish the aggregate and return a KVIterator.
+ */
+ def finishAggregate(
+ hashMap: UnsafeFixedWidthAggregationMap,
+ sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = {
+
+ // update peak execution memory
+ val mapMemory = hashMap.getPeakMemoryUsedBytes
+ val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
+ val peakMemory = Math.max(mapMemory, sorterMemory)
+ val metrics = TaskContext.get().taskMetrics()
+ metrics.incPeakExecutionMemory(peakMemory)
+
+ if (sorter == null) {
+ // not spilled
+ return hashMap.iterator()
+ }
+
+ // merge the final hashMap into sorter
+ sorter.merge(hashMap.destructAndCreateExternalSorter())
+ hashMap.free()
+ val sortedIter = sorter.sortedIterator()
+
+ // Create a KVIterator based on the sorted iterator.
+ new KVIterator[UnsafeRow, UnsafeRow] {
+
+ // Create a MutableProjection to merge the rows with the same key together
+ val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
+ val mergeProjection = newMutableProjection(
+ mergeExpr,
+ bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
+ subexpressionEliminationEnabled)()
+ val joinedRow = new JoinedRow()
+
+ var currentKey: UnsafeRow = null
+ var currentRow: UnsafeRow = null
+ var nextKey: UnsafeRow = if (sortedIter.next()) {
+ sortedIter.getKey
+ } else {
+ null
+ }
+
+ override def next(): Boolean = {
+ if (nextKey != null) {
+ currentKey = nextKey.copy()
+ currentRow = sortedIter.getValue.copy()
+ nextKey = null
+ // use the first row as the aggregate buffer
+ mergeProjection.target(currentRow)
+
+ // merge the following rows with the same key
+ var findNextGroup = false
+ while (!findNextGroup && sortedIter.next()) {
+ val key = sortedIter.getKey
+ if (currentKey.equals(key)) {
+ mergeProjection(joinedRow(currentRow, sortedIter.getValue))
+ } else {
+ // We found a new group.
+ findNextGroup = true
+ nextKey = key
+ }
+ }
+
+ true
+ } else {
+ false
+ }
+ }
+
+ override def getKey: UnsafeRow = currentKey
+ override def getValue: UnsafeRow = currentRow
+ override def close(): Unit = {
+ sortedIter.close()
+ }
+ }
+ }
+
+ /**
+ * Generate the code for output.
+ */
+ private def generateResultCode(
+ ctx: CodegenContext,
+ keyTerm: String,
+ bufferTerm: String,
+ plan: String): String = {
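+ // Final/Complete: evaluate resultExpressions from key and buffer; Partial/PartialMerge: join key
+ // and buffer into one UnsafeRow; otherwise the output is computed from the grouping key only.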
+ if (modes.contains(Final) || modes.contains(Complete)) {
+ // generate output using resultExpressions
+ ctx.currentVars = null
+ ctx.INPUT_ROW = keyTerm
+ val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+ BoundReference(i, e.dataType, e.nullable).gen(ctx)
+ }
+ ctx.INPUT_ROW = bufferTerm
+ val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+ BoundReference(i, e.dataType, e.nullable).gen(ctx)
+ }
+ // evaluate the aggregation result
+ ctx.currentVars = bufferVars
+ val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
+ BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+ }
+ // generate the final result
+ ctx.currentVars = keyVars ++ aggResults
+ val inputAttrs = groupingAttributes ++ aggregateAttributes
+ val resultVars = resultExpressions.map { e =>
+ BindReferences.bindReference(e, inputAttrs).gen(ctx)
+ }
+ s"""
+ ${keyVars.map(_.code).mkString("\n")}
+ ${bufferVars.map(_.code).mkString("\n")}
+ ${aggResults.map(_.code).mkString("\n")}
+ ${resultVars.map(_.code).mkString("\n")}
+
+ ${consume(ctx, resultVars)}
+ """
+
+ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+ // This should be the last operator in the stage, so output UnsafeRows directly
+ val joinerTerm = ctx.freshName("unsafeRowJoiner")
+ ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
+ s"$joinerTerm = $plan.createUnsafeJoiner();")
+ val resultRow = ctx.freshName("resultRow")
+ s"""
+ UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
+ ${consume(ctx, null, resultRow)}
+ """
+
+ } else {
+ // generate result based on grouping key
+ ctx.INPUT_ROW = keyTerm
+ ctx.currentVars = null
+ val eval = resultExpressions.map { e =>
+ BindReferences.bindReference(e, groupingAttributes).gen(ctx)
+ }
+ s"""
+ ${eval.map(_.code).mkString("\n")}
+ ${consume(ctx, eval)}
+ """
+ }
+ }
+
+ private def doProduceWithKeys(ctx: CodegenContext): String = {
+ val initAgg = ctx.freshName("initAgg")
+ ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+ // create hashMap
+ val thisPlan = ctx.addReferenceObj("plan", this)
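+ // register this operator as a reference so the generated code can call back into createHashMap()/finishAggregate()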
+ hashMapTerm = ctx.freshName("hashMap")
+ val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
+ ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+ sorterTerm = ctx.freshName("sorter")
+ ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
+
+ // Create a name for the iterator over the hash map
+ val iterTerm = ctx.freshName("mapIter")
+ ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
+
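+ // build the hash map (and the sorter, if spilling happens) by consuming the whole child input in a separate method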
+ val doAgg = ctx.freshName("doAggregateWithKeys")
+ ctx.addNewFunction(doAgg,
+ s"""
+ private void $doAgg() throws java.io.IOException {
+ ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+
+ $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
+ }
+ """)
+
+ // generate code for output
+ val keyTerm = ctx.freshName("aggKey")
+ val bufferTerm = ctx.freshName("aggBuffer")
+ val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
+
+ s"""
+ if (!$initAgg) {
+ $initAgg = true;
+ $doAgg();
+ }
+
+ // output the result
+ while ($iterTerm.next()) {
+ UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
+ UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
+ $outputCode
+
+ if (shouldStop()) return;
+ }
+
+ $iterTerm.close();
+ if ($sorterTerm == null) {
+ $hashMapTerm.free();
+ }
+ """
+ }
+
+ private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+
+ // create grouping key
+ ctx.currentVars = input
+ val keyCode = GenerateUnsafeProjection.createCode(
+ ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+ val key = keyCode.value
+ val buffer = ctx.freshName("aggBuffer")
+
+ // only have DeclarativeAggregate
+ val updateExpr = aggregateExpressions.flatMap { e =>
+ e.mode match {
+ case Partial | Complete =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+ case PartialMerge | Final =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+ }
+ }
+
+ val inputAttr = bufferAttributes ++ child.output
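+ // null slots in currentVars make the buffer columns be read from INPUT_ROW ($buffer),
+ // while the input columns reuse the already generated variables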
+ ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+ ctx.INPUT_ROW = buffer
+ // TODO: support subexpression elimination
+ val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
+ val updates = evals.zipWithIndex.map { case (ev, i) =>
+ val dt = updateExpr(i).dataType
+ ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
+ }
+
+ val (checkFallback, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) {
+ val countTerm = ctx.freshName("fallbackCounter")
+ ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
+ (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;")
+ } else {
+ ("true", "", "")
+ }
+
+ // We try hash-map based in-memory aggregation first. If there is not enough memory (the
+ // hash map returns null for a new key), we spill the hash map to disk to free memory, then
+ // continue to do in-memory aggregation and spilling until all the rows have been processed.
+ // Finally, sort the spilled aggregate buffers by key and merge the buffers with the same key.
+ s"""
+ // generate grouping key
+ ${keyCode.code.trim}
+ UnsafeRow $buffer = null;
+ if ($checkFallback) {
+ // try to get the buffer from hash map
+ $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ }
+ if ($buffer == null) {
+ if ($sorterTerm == null) {
+ $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+ } else {
+ $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+ }
+ $resetCounter
+ // the hash map has been spilled, so there should be enough memory now;
+ // try to allocate the buffer again.
+ $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ if ($buffer == null) {
+ // failed to allocate the first page
+ throw new OutOfMemoryError("Not enough memory for aggregation");
+ }
+ }
+ $incCounter
+
+ // evaluate aggregate function
+ ${evals.map(_.code).mkString("\n").trim}
+ // update aggregate buffer
+ ${updates.mkString("\n").trim}
+ """
+ }
+
override def simpleString: String = {
val allAggregateExpressions = aggregateExpressions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 5a19920add717..812e696338362 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedMutableProjection, MutableRow}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, ImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
@@ -361,13 +361,7 @@ private[sql] case class ScalaUDAF(
val inputAttributes = childrenSchema.toAttributes
log.debug(
s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
- try {
- GenerateMutableProjection.generate(children, inputAttributes)()
- } catch {
- case e: Exception =>
- log.error("Failed to generate mutable projection, fallback to interpreted", e)
- new InterpretedMutableProjection(children, inputAttributes)
- }
+ GenerateMutableProjection.generate(children, inputAttributes)()
}
private[this] lazy val inputToScalaConverters: Any => Any =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 83379ae90f703..1e113ccd4e137 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -33,15 +33,14 @@ object Utils {
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
- val groupingAttributes = groupingExpressions.map(_.toAttribute)
val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
val completeAggregateAttributes = completeAggregateExpressions.map {
expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
}
SortBasedAggregate(
- requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes,
+ requiredChildDistributionExpressions = Some(groupingExpressions),
+ groupingExpressions = groupingExpressions,
aggregateExpressions = completeAggregateExpressions,
aggregateAttributes = completeAggregateAttributes,
initialInputBufferOffset = 0,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6deb72adad5ec..f63e8a9b6d79d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,16 +17,13 @@
package org.apache.spark.sql.execution
-import org.apache.spark.{HashPartitioner, SparkEnv}
-import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
-import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
-import org.apache.spark.util.MutablePair
import org.apache.spark.util.random.PoissonSampler
case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
@@ -37,11 +34,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
+ child.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
@@ -76,18 +77,27 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
+ child.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val expr = ExpressionCanonicalizer.execute(
BindReferences.bindReference(condition, child.output))
ctx.currentVars = input
val eval = expr.gen(ctx)
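+ // skip the null check when the predicate expression cannot be null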
+ val nullCheck = if (expr.nullable) {
+ s"!${eval.isNull} &&"
+ } else {
+ s""
+ }
s"""
| ${eval.code}
- | if (!${eval.isNull} && ${eval.value}) {
+ | if ($nullCheck ${eval.value}) {
| ${consume(ctx, ctx.currentVars)}
| }
""".stripMargin
@@ -153,17 +163,21 @@ case class Range(
output: Seq[Attribute])
extends LeafNode with CodegenSupport {
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
- val initTerm = ctx.freshName("range_initRange")
+ override def upstream(): RDD[InternalRow] = {
+ sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ val initTerm = ctx.freshName("initRange")
ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
- val partitionEnd = ctx.freshName("range_partitionEnd")
+ val partitionEnd = ctx.freshName("partitionEnd")
ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
- val number = ctx.freshName("range_number")
+ val number = ctx.freshName("number")
ctx.addMutableState("long", number, s"$number = 0L;")
- val overflow = ctx.freshName("range_overflow")
+ val overflow = ctx.freshName("overflow")
ctx.addMutableState("boolean", overflow, s"$overflow = false;")
- val value = ctx.freshName("range_value")
+ val value = ctx.freshName("value")
val ev = ExprCode("", "false", value)
val BigInt = classOf[java.math.BigInteger].getName
val checkEnd = if (step > 0) {
@@ -172,38 +186,42 @@ case class Range(
s"$number > $partitionEnd"
}
- val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
- .map(i => InternalRow(i))
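+ // compute this partition's [start, end) range with BigInteger arithmetic to avoid overflow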
+ ctx.addNewFunction("initRange",
+ s"""
+ | private void initRange(int idx) {
+ | $BigInt index = $BigInt.valueOf(idx);
+ | $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
+ | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
+ | $BigInt step = $BigInt.valueOf(${step}L);
+ | $BigInt start = $BigInt.valueOf(${start}L);
+ |
+ | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
+ | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+ | $number = Long.MAX_VALUE;
+ | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+ | $number = Long.MIN_VALUE;
+ | } else {
+ | $number = st.longValue();
+ | }
+ |
+ | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
+ | .multiply(step).add(start);
+ | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+ | $partitionEnd = Long.MAX_VALUE;
+ | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+ | $partitionEnd = Long.MIN_VALUE;
+ | } else {
+ | $partitionEnd = end.longValue();
+ | }
+ | }
+ """.stripMargin)
- val code = s"""
+ s"""
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
| if (input.hasNext()) {
- | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0));
- | $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
- | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
- | $BigInt step = $BigInt.valueOf(${step}L);
- | $BigInt start = $BigInt.valueOf(${start}L);
- |
- | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
- | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
- | $number = Long.MAX_VALUE;
- | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
- | $number = Long.MIN_VALUE;
- | } else {
- | $number = st.longValue();
- | }
- |
- | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
- | .multiply(step).add(start);
- | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
- | $partitionEnd = Long.MAX_VALUE;
- | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
- | $partitionEnd = Long.MIN_VALUE;
- | } else {
- | $partitionEnd = end.longValue();
- | }
+ | initRange(((InternalRow) input.next()).getInt(0));
| } else {
| return;
| }
@@ -216,14 +234,10 @@ case class Range(
| $overflow = true;
| }
| ${consume(ctx, Seq(ev))}
+ |
+ | if (shouldStop()) return;
| }
""".stripMargin
-
- (rdd, code)
- }
-
- def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
- throw new UnsupportedOperationException
}
protected override def doExecute(): RDD[InternalRow] = {
@@ -289,96 +303,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan {
sparkContext.union(children.map(_.execute()))
}
-/**
- * Take the first limit elements. Note that the implementation is different depending on whether
- * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
- * this operator uses something similar to Spark's take method on the Spark driver. If it is not
- * terminal or is invoked using execute, we first take the limit on each partition, and then
- * repartition all the data to a single partition to compute the global limit.
- */
-case class Limit(limit: Int, child: SparkPlan)
- extends UnaryNode {
- // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
- // partition local limit -> exchange into one partition -> partition local limit again
-
- /** We must copy rows when sort based shuffle is on */
- private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
-
- override def output: Seq[Attribute] = child.output
- override def outputPartitioning: Partitioning = SinglePartition
-
- override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
-
- protected override def doExecute(): RDD[InternalRow] = {
- val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) {
- child.execute().mapPartitionsInternal { iter =>
- iter.take(limit).map(row => (false, row.copy()))
- }
- } else {
- child.execute().mapPartitionsInternal { iter =>
- val mutablePair = new MutablePair[Boolean, InternalRow]()
- iter.take(limit).map(row => mutablePair.update(false, row))
- }
- }
- val part = new HashPartitioner(1)
- val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part)
- shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
- shuffled.mapPartitionsInternal(_.take(limit).map(_._2))
- }
-}
-
-/**
- * Take the first limit elements as defined by the sortOrder, and do projection if needed.
- * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
- * or having a [[Project]] operator between them.
- * This could have been named TopK, but Spark's top operator does the opposite in ordering
- * so we name it TakeOrdered to avoid confusion.
- */
-case class TakeOrderedAndProject(
- limit: Int,
- sortOrder: Seq[SortOrder],
- projectList: Option[Seq[NamedExpression]],
- child: SparkPlan) extends UnaryNode {
-
- override def output: Seq[Attribute] = {
- val projectOutput = projectList.map(_.map(_.toAttribute))
- projectOutput.getOrElse(child.output)
- }
-
- override def outputPartitioning: Partitioning = SinglePartition
-
- // We need to use an interpreted ordering here because generated orderings cannot be serialized
- // and this ordering needs to be created on the driver in order to be passed into Spark core code.
- private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
-
- private def collectData(): Array[InternalRow] = {
- val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
- if (projectList.isDefined) {
- val proj = UnsafeProjection.create(projectList.get, child.output)
- data.map(r => proj(r).copy())
- } else {
- data
- }
- }
-
- override def executeCollect(): Array[InternalRow] = {
- collectData()
- }
-
- // TODO: Terminal split should be implemented differently from non-terminal split.
- // TODO: Pick num splits based on |limit|.
- protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
-
- override def outputOrdering: Seq[SortOrder] = sortOrder
-
- override def simpleString: String = {
- val orderByString = sortOrder.mkString("[", ",", "]")
- val outputString = output.mkString("[", ",", "]")
-
- s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
- }
-}
-
/**
* Return a new RDD that has exactly `numPartitions` partitions.
* Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
@@ -410,18 +334,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode {
}
}
-/**
- * Returns the rows in left that also appear in right using the built in spark
- * intersection function.
- */
-case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode {
- override def output: Seq[Attribute] = children.head.output
-
- protected override def doExecute(): RDD[InternalRow] = {
- left.execute().map(_.copy()).intersection(right.execute().map(_.copy()))
- }
-}
-
/**
* A plan node that does nothing but lie about the output of its child. Used to spice a
* (hopefully structurally equivalent) tree from a different optimization sequence into an already
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 3cfa3dfd9c7ec..c6adb583f931b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -404,7 +404,17 @@ case class DescribeFunction(
result
}
- case None => Seq(Row(s"Function: $functionName is not found."))
+ case None => Seq(Row(s"Function: $functionName not found."))
}
}
}
+
+case class SetDatabaseCommand(databaseName: String) extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ sqlContext.catalog.setCurrentDatabase(databaseName)
+ Seq.empty[Row]
+ }
+
+ override val output: Seq[Attribute] = Seq.empty
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
deleted file mode 100644
index f4766b037027d..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
+++ /dev/null
@@ -1,193 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql.execution.datasources
-
-import scala.language.implicitConversions
-import scala.util.matching.Regex
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserInterface, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.util.DataTypeParser
-import org.apache.spark.sql.types._
-
-/**
- * A parser for foreign DDL commands.
- */
-class DDLParser(fallback: => ParserInterface)
- extends AbstractSparkSQLParser with DataTypeParser with Logging {
-
- override def parseExpression(sql: String): Expression = fallback.parseExpression(sql)
-
- override def parseTableIdentifier(sql: String): TableIdentifier = {
- fallback.parseTableIdentifier(sql)
- }
-
- def parse(input: String, exceptionOnError: Boolean): LogicalPlan = {
- try {
- parsePlan(input)
- } catch {
- case ddlException: DDLException => throw ddlException
- case _ if !exceptionOnError => fallback.parsePlan(input)
- case x: Throwable => throw x
- }
- }
-
- // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
- // properties via reflection the class in runtime for constructing the SqlLexical object
- protected val CREATE = Keyword("CREATE")
- protected val TEMPORARY = Keyword("TEMPORARY")
- protected val TABLE = Keyword("TABLE")
- protected val IF = Keyword("IF")
- protected val NOT = Keyword("NOT")
- protected val EXISTS = Keyword("EXISTS")
- protected val USING = Keyword("USING")
- protected val OPTIONS = Keyword("OPTIONS")
- protected val DESCRIBE = Keyword("DESCRIBE")
- protected val EXTENDED = Keyword("EXTENDED")
- protected val AS = Keyword("AS")
- protected val COMMENT = Keyword("COMMENT")
- protected val REFRESH = Keyword("REFRESH")
-
- protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable
-
- protected def start: Parser[LogicalPlan] = ddl
-
- /**
- * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable
- * USING org.apache.spark.sql.avro
- * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
- * or
- * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable(intField int, stringField string...)
- * USING org.apache.spark.sql.avro
- * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
- * or
- * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable
- * USING org.apache.spark.sql.avro
- * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
- * AS SELECT ...
- */
- protected lazy val createTable: Parser[LogicalPlan] = {
- // TODO: Support database.table.
- (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~
- tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ {
- case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query =>
- if (temp.isDefined && allowExisting.isDefined) {
- throw new DDLException(
- "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.")
- }
-
- val options = opts.getOrElse(Map.empty[String, String])
- if (query.isDefined) {
- if (columns.isDefined) {
- throw new DDLException(
- "a CREATE TABLE AS SELECT statement does not allow column definitions.")
- }
- // When IF NOT EXISTS clause appears in the query, the save mode will be ignore.
- val mode = if (allowExisting.isDefined) {
- SaveMode.Ignore
- } else if (temp.isDefined) {
- SaveMode.Overwrite
- } else {
- SaveMode.ErrorIfExists
- }
-
- val queryPlan = fallback.parsePlan(query.get)
- CreateTableUsingAsSelect(tableIdent,
- provider,
- temp.isDefined,
- Array.empty[String],
- bucketSpec = None,
- mode,
- options,
- queryPlan)
- } else {
- val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
- CreateTableUsing(
- tableIdent,
- userSpecifiedSchema,
- provider,
- temp.isDefined,
- options,
- allowExisting.isDefined,
- managedIfNoPath = false)
- }
- }
- }
-
- // This is the same as tableIdentifier in SqlParser.
- protected lazy val tableIdentifier: Parser[TableIdentifier] =
- (ident <~ ".").? ~ ident ^^ {
- case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName)
- }
-
- protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
-
- /*
- * describe [extended] table avroTable
- * This will display all columns of table `avroTable` includes column_name,column_type,comment
- */
- protected lazy val describeTable: Parser[LogicalPlan] =
- (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ {
- case e ~ tableIdent =>
- DescribeCommand(UnresolvedRelation(tableIdent, None), e.isDefined)
- }
-
- protected lazy val refreshTable: Parser[LogicalPlan] =
- REFRESH ~> TABLE ~> tableIdentifier ^^ {
- case tableIndet =>
- RefreshTable(tableIndet)
- }
-
- protected lazy val options: Parser[Map[String, String]] =
- "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
-
- protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
-
- override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch(
- s"identifier matching regex $regex", {
- case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str
- case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str
- }
- )
-
- protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ {
- case name => name
- }
-
- protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ {
- case parts => parts.mkString(".")
- }
-
- protected lazy val pair: Parser[(String, String)] =
- optionName ~ stringLit ^^ { case k ~ v => (k, v) }
-
- protected lazy val column: Parser[StructField] =
- ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm =>
- val meta = cm match {
- case Some(comment) =>
- new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build()
- case None => Metadata.empty
- }
-
- StructField(columnName, typ, nullable = true, meta)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index da9320ffb61c3..c24967abeb33e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -29,12 +29,14 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.collection.BitSet
/**
* A Strategy for planning scans over data sources defined using the sources API.
@@ -97,10 +99,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
(partitionAndNormalColumnAttrs ++ projects).toSeq
}
+ // Prune the buckets based on the pushed filters that do not contain the partitioning key,
+ // since the bucketing key is not allowed to use columns from the partitioning key
+ val bucketSet = getBuckets(pushedFilters, t.getBucketSpec)
+
val scan = buildPartitionedTableScan(
l,
partitionAndNormalColumnProjs,
pushedFilters,
+ bucketSet,
t.partitionSpec.partitionColumns,
selectedPartitions)
@@ -124,11 +131,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
val sharedHadoopConf = SparkHadoopUtil.get.conf
val confBroadcast =
t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
+ // Prune the buckets based on the filters
+ val bucketSet = getBuckets(filters, t.getBucketSpec)
pruneFilterProject(
l,
projects,
filters,
- (a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil
+ (a, f) =>
+ t.buildInternalScan(a.map(_.name).toArray, f, bucketSet, t.paths, confBroadcast)) :: Nil
case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
execution.PhysicalRDD.createFromDataSource(
@@ -150,6 +160,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
logicalRelation: LogicalRelation,
projections: Seq[NamedExpression],
filters: Seq[Expression],
+ buckets: Option[BitSet],
partitionColumns: StructType,
partitions: Array[Partition]): SparkPlan = {
val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation]
@@ -174,7 +185,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// assuming partition columns data stored in data files are always consistent with those
// partition values encoded in partition directory paths.
val dataRows = relation.buildInternalScan(
- requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast)
+ requiredDataColumns.map(_.name).toArray, filters, buckets, Array(dir), confBroadcast)
// Merges data values with partition values.
mergeWithPartitionValues(
@@ -251,6 +262,69 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
}
}
+ // Get the bucket ID based on the bucketing values.
+ // Restriction: bucket pruning only works when there is exactly one bucketing column.
+ def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
+ val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType))
+ mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
+ val bucketIdGeneration = UnsafeProjection.create(
+ HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
+ bucketColumn :: Nil)
+
+ bucketIdGeneration(mutableRow).getInt(0)
+ }
+
+ // Get the bucket BitSet by reading the filters that only contain bucketing keys.
+ // Note: when the returned BitSet is None, no pruning is possible.
+ // Restriction: bucket pruning only works when there is exactly one bucketing column.
+ private def getBuckets(
+ filters: Seq[Expression],
+ bucketSpec: Option[BucketSpec]): Option[BitSet] = {
+
+ if (bucketSpec.isEmpty ||
+ bucketSpec.get.numBuckets == 1 ||
+ bucketSpec.get.bucketColumnNames.length != 1) {
+ // None means all the buckets need to be scanned
+ return None
+ }
+
+ // Just take the first one, because bucket pruning only works when there is exactly one bucketing column
+ val bucketColumnName = bucketSpec.get.bucketColumnNames.head
+ val numBuckets = bucketSpec.get.numBuckets
+ val matchedBuckets = new BitSet(numBuckets)
+ matchedBuckets.clear()
+
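+ // Only predicates that pin the bucketing column to specific values (=, <=>, In with literals,
+ // IS NULL) can select individual buckets; anything else leaves the BitSet untouched.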
+ filters.foreach {
+ case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, v))
+ case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, v))
+ case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, v))
+ case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, v))
+ // The Optimizer only converts In to InSet when there are more than a certain number of
+ // items, so it is possible we still get an In expression here that needs to be pushed
+ // down.
+ case expressions.In(a: Attribute, list)
+ if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
+ val hSet = list.map(e => e.eval(EmptyRow))
+ hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e)))
+ case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, null))
+ case _ =>
+ }
+
+ logInfo {
+ val selected = matchedBuckets.cardinality()
+ val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100
+ s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions."
+ }
+
+ // None means all the buckets need to be scanned
+ if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets)
+ }
+
protected def prunePartitions(
predicates: Seq[Expression],
partitionSpec: PartitionSpec): Seq[Partition] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index cc8dcf59307f2..7702f535ad2f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -29,11 +29,11 @@ import org.apache.hadoop.util.StringUtils
import org.apache.spark.Logging
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
+import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
import org.apache.spark.util.Utils
-
case class ResolvedDataSource(provider: Class[_], relation: BaseRelation)
@@ -92,6 +92,36 @@ object ResolvedDataSource extends Logging {
}
}
+ def createSource(
+ sqlContext: SQLContext,
+ userSpecifiedSchema: Option[StructType],
+ providerName: String,
+ options: Map[String, String]): Source = {
+ val provider = lookupDataSource(providerName).newInstance() match {
+ case s: StreamSourceProvider => s
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Data source $providerName does not support streamed reading")
+ }
+
+ provider.createSource(sqlContext, options, userSpecifiedSchema)
+ }
+
+ def createSink(
+ sqlContext: SQLContext,
+ providerName: String,
+ options: Map[String, String],
+ partitionColumns: Seq[String]): Sink = {
+ val provider = lookupDataSource(providerName).newInstance() match {
+ case s: StreamSinkProvider => s
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Data source $providerName does not support streamed writing")
+ }
+
+ provider.createSink(sqlContext, options, partitionColumns)
+ }
+
/** Create a [[ResolvedDataSource]] for reading data in. */
def apply(
sqlContext: SQLContext,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
index edd87c2d8ed07..3605150b3b767 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
@@ -127,6 +127,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
val conf = getConf(isDriverSide = false)
val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
+ val existingBytesRead = inputMetrics.bytesRead
// Sets the thread local variable for the file's name
split.serializableHadoopSplit.value match {
@@ -142,9 +143,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
case _ => None
}
+ // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics.
+ // If we do a coalesce, however, we are likely to compute multiple partitions in the same
+ // task and in the same thread, in which case we need to avoid overriding values written by
+ // previous partitions (SPARK-13071).
def updateBytesRead(): Unit = {
getBytesReadCallback.foreach { getBytesRead =>
- inputMetrics.setBytesRead(getBytesRead())
+ inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
}
}
@@ -209,7 +214,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
}
havePair = false
if (!finished) {
- inputMetrics.incRecordsRead(1)
+ inputMetrics.incRecordsReadInternal(1)
}
if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
updateBytesRead()
@@ -241,7 +246,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
- inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength)
+ inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength)
} catch {
case e: java.io.IOException =>
logWarning("Unable to get input size to set InputMetrics for task", e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index c3603936dfd2e..a141b58d3d72c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types._
@@ -32,7 +33,7 @@ import org.apache.spark.sql.types._
*/
case class DescribeCommand(
table: LogicalPlan,
- isExtended: Boolean) extends LogicalPlan with Command {
+ isExtended: Boolean) extends LogicalPlan with logical.Command {
override def children: Seq[LogicalPlan] = Seq.empty
@@ -59,7 +60,7 @@ case class CreateTableUsing(
temporary: Boolean,
options: Map[String, String],
allowExisting: Boolean,
- managedIfNoPath: Boolean) extends LogicalPlan with Command {
+ managedIfNoPath: Boolean) extends LogicalPlan with logical.Command {
override def output: Seq[Attribute] = Seq.empty
override def children: Seq[LogicalPlan] = Seq.empty
@@ -67,8 +68,8 @@ case class CreateTableUsing(
/**
* A node used to support CTAS statements and saveAsTable for the data source API.
- * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer
- * can analyze the logical plan that will be used to populate the table.
+ * This node is a [[logical.UnaryNode]] instead of a [[logical.Command]] because we want the
+ * analyzer to be able to analyze the logical plan that will be used to populate the table.
* So, [[PreWriteCheck]] can detect cases that are not allowed.
*/
case class CreateTableUsingAsSelect(
@@ -79,7 +80,7 @@ case class CreateTableUsingAsSelect(
bucketSpec: Option[BucketSpec],
mode: SaveMode,
options: Map[String, String],
- child: LogicalPlan) extends UnaryNode {
+ child: LogicalPlan) extends logical.UnaryNode {
override def output: Seq[Attribute] = Seq.empty[Attribute]
}
@@ -169,8 +170,3 @@ class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
override def -(key: String): Map[String, String] = baseMap - key.toLowerCase
}
-
-/**
- * The exception thrown from the DDL parser.
- */
-class DDLException(message: String) extends RuntimeException(message)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
index 44d5e4ff7ec8b..8b773ddfcb656 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
@@ -134,8 +134,12 @@ private[json] object InferSchema {
val v = parser.getDecimalValue
DecimalType(v.precision(), v.scale())
case FLOAT | DOUBLE =>
- // TODO(davies): Should we use decimal if possible?
- DoubleType
+ if (configOptions.floatAsBigDecimal) {
+ val v = parser.getDecimalValue
+ DecimalType(v.precision(), v.scale())
+ } else {
+ DoubleType
+ }
}
case VALUE_TRUE | VALUE_FALSE => BooleanType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
index fe5b20697e40e..31a95ed461215 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
@@ -34,6 +34,8 @@ private[sql] class JSONOptions(
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
val primitivesAsString =
parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false)
+ val floatAsBigDecimal =
+ parameters.get("floatAsBigDecimal").map(_.toBoolean).getOrElse(false)
val allowComments =
parameters.get("allowComments").map(_.toBoolean).getOrElse(false)
val allowUnquotedFieldNames =
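With the option threaded through `JSONOptions`, schema inference can map JSON floating-point literals to `DecimalType` instead of `DoubleType`. A hedged usage sketch (the input path is hypothetical; `sqlContext` is assumed to be an existing `SQLContext`):

```scala
// Opt in to inferring JSON floats/doubles as DecimalType; defaults to false.
val df = sqlContext.read
  .option("floatAsBigDecimal", "true")
  .json("/tmp/example.json")            // hypothetical input path

df.printSchema()                        // floating-point fields now show up as decimal(p, s)
```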
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index e9b734b0abf50..5a5cb5cf03d4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -207,11 +207,26 @@ private[sql] object ParquetFilters {
*/
}
+ /**
+ * SPARK-11955: The optional fields will have metadata StructType.metadataKeyForOptionalField.
+ * These fields exist in only one side of merged schemas. Because of that, we can't push down
+ * filters that reference such fields, otherwise the Parquet library will throw an exception.
+ * Here we filter out such fields.
+ */
+ private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match {
+ case StructType(fields) =>
+ fields.filter { f =>
+ !f.metadata.contains(StructType.metadataKeyForOptionalField) ||
+ !f.metadata.getBoolean(StructType.metadataKeyForOptionalField)
+ }.map(f => f.name -> f.dataType) ++ fields.flatMap { f => getFieldMap(f.dataType) }
+ case _ => Array.empty[(String, DataType)]
+ }
+
/**
* Converts data sources filters to Parquet filter predicates.
*/
def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = {
- val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap
+ val dataTypeOf = getFieldMap(schema).toMap
relaxParquetValidTypeMap
@@ -231,29 +246,29 @@ private[sql] object ParquetFilters {
// Probably I missed something and obviously this should be changed.
predicate match {
- case sources.IsNull(name) =>
+ case sources.IsNull(name) if dataTypeOf.contains(name) =>
makeEq.lift(dataTypeOf(name)).map(_(name, null))
- case sources.IsNotNull(name) =>
+ case sources.IsNotNull(name) if dataTypeOf.contains(name) =>
makeNotEq.lift(dataTypeOf(name)).map(_(name, null))
- case sources.EqualTo(name, value) =>
+ case sources.EqualTo(name, value) if dataTypeOf.contains(name) =>
makeEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.Not(sources.EqualTo(name, value)) =>
+ case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) =>
makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.EqualNullSafe(name, value) =>
+ case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) =>
makeEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.Not(sources.EqualNullSafe(name, value)) =>
+ case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) =>
makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.LessThan(name, value) =>
+ case sources.LessThan(name, value) if dataTypeOf.contains(name) =>
makeLt.lift(dataTypeOf(name)).map(_(name, value))
- case sources.LessThanOrEqual(name, value) =>
+ case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) =>
makeLtEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.GreaterThan(name, value) =>
+ case sources.GreaterThan(name, value) if dataTypeOf.contains(name) =>
makeGt.lift(dataTypeOf(name)).map(_(name, value))
- case sources.GreaterThanOrEqual(name, value) =>
+ case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) =>
makeGtEq.lift(dataTypeOf(name)).map(_(name, value))
case sources.In(name, valueSet) =>
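Every case above is now guarded by `dataTypeOf.contains(name)`, so a predicate over a column that exists in only one side of a merged schema quietly yields `None` (no pushdown) instead of reaching the Parquet library and failing. The guard pattern in isolation, using stand-in types rather than Spark's `sources.Filter`:

```scala
// Reduced illustration of the added guards: filters on columns missing from the
// (merged) schema are simply not pushed down, rather than failing later in Parquet.
object FilterGuardSketch {
  sealed trait Filter
  case class EqualTo(name: String, value: Any) extends Filter

  def createFilter(knownColumns: Map[String, String], predicate: Filter): Option[String] =
    predicate match {
      case EqualTo(name, value) if knownColumns.contains(name) =>
        Some(s"eq($name, $value)")      // stand-in for a real Parquet FilterPredicate
      case _ =>
        None                            // unknown column: skip pushdown entirely
    }

  def main(args: Array[String]): Unit = {
    val known = Map("a" -> "int")
    assert(createFilter(known, EqualTo("a", 1)).isDefined)
    assert(createFilter(known, EqualTo("missing", 1)).isEmpty)
  }
}
```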
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index b460ec1d26047..1e686d41f41db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -258,7 +258,12 @@ private[sql] class ParquetRelation(
job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]])
ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport])
- CatalystWriteSupport.setSchema(dataSchema, conf)
+
+ // We want to clear this temporary metadata before saving it into the Parquet file.
+ // This metadata is only useful for detecting optional columns when pushing down filters.
+ val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField,
+ dataSchema).asInstanceOf[StructType]
+ CatalystWriteSupport.setSchema(dataSchemaToWrite, conf)
// Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema)
// and `CatalystWriteSupport` (writing actual rows to Parquet files).
@@ -304,10 +309,6 @@ private[sql] class ParquetRelation(
val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
- // When merging schemas is enabled and the column of the given filter does not exist,
- // Parquet emits an exception which is an issue of Parquet (PARQUET-389).
- val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown
-
// Parquet row group size. We will use this value as the value for
// mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value
// of these flags are smaller than the parquet row group size.
@@ -321,7 +322,7 @@ private[sql] class ParquetRelation(
dataSchema,
parquetBlockSize,
useMetadataCache,
- safeParquetFilterPushDown,
+ parquetFilterPushDown,
assumeBinaryIsString,
assumeInt96IsTimestamp) _
@@ -799,12 +800,37 @@ private[sql] object ParquetRelation extends Logging {
assumeInt96IsTimestamp = assumeInt96IsTimestamp,
writeLegacyParquetFormat = writeLegacyParquetFormat)
- footers.map { footer =>
- ParquetRelation.readSchemaFromFooter(footer, converter)
- }.reduceLeftOption(_ merge _).iterator
+ if (footers.isEmpty) {
+ Iterator.empty
+ } else {
+ var mergedSchema = ParquetRelation.readSchemaFromFooter(footers.head, converter)
+ footers.tail.foreach { footer =>
+ val schema = ParquetRelation.readSchemaFromFooter(footer, converter)
+ try {
+ mergedSchema = mergedSchema.merge(schema)
+ } catch { case cause: SparkException =>
+ throw new SparkException(
+ s"Failed merging schema of file ${footer.getFile}:\n${schema.treeString}", cause)
+ }
+ }
+ Iterator.single(mergedSchema)
+ }
}.collect()
- partiallyMergedSchemas.reduceLeftOption(_ merge _)
+ if (partiallyMergedSchemas.isEmpty) {
+ None
+ } else {
+ var finalSchema = partiallyMergedSchemas.head
+ partiallyMergedSchemas.tail.foreach { schema =>
+ try {
+ finalSchema = finalSchema.merge(schema)
+ } catch { case cause: SparkException =>
+ throw new SparkException(
+ s"Failed merging schema:\n${schema.treeString}", cause)
+ }
+ }
+ Some(finalSchema)
+ }
}
/**
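Replacing the anonymous `reduceLeftOption(_ merge _)` with an explicit loop lets the error name the footer file (or partially merged schema) that failed to merge. The same wrap-and-rethrow-with-context pattern, condensed to plain Scala with an illustrative schema type:

```scala
// Condensed version of the merge loop above: merge one schema at a time so a
// failure can report which input caused it. SchemaSketch is an illustrative stand-in.
object SchemaMergeSketch {
  type SchemaSketch = Map[String, String]            // field name -> type name

  def merge(a: SchemaSketch, b: SchemaSketch): SchemaSketch = {
    b.foreach { case (name, tpe) =>
      a.get(name).foreach { existing =>
        require(existing == tpe, s"Field $name has conflicting types $existing and $tpe")
      }
    }
    a ++ b
  }

  def mergeAll(named: Seq[(String, SchemaSketch)]): Option[SchemaSketch] =
    named.headOption.map { case (_, first) =>
      named.tail.foldLeft(first) { case (merged, (file, schema)) =>
        try merge(merged, schema)
        catch {
          case cause: IllegalArgumentException =>
            throw new IllegalArgumentException(s"Failed merging schema of file $file", cause)
        }
      }
    }
}
```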
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 04640711d99d0..943ad31c0cef5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -20,14 +20,17 @@ package org.apache.spark.sql.execution.joins
import scala.concurrent._
import scala.concurrent.duration._
-import org.apache.spark.{InternalAccumulator, TaskContext}
+import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.collection.CompactBuffer
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -42,7 +45,7 @@ case class BroadcastHashJoin(
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
- extends BinaryNode with HashJoin {
+ extends BinaryNode with HashJoin with CodegenSupport {
override private[sql] lazy val metrics = Map(
"numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
@@ -74,7 +77,7 @@ case class BroadcastHashJoin(
// broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- future {
+ Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sparkContext, executionId) {
@@ -117,6 +120,87 @@ case class BroadcastHashJoin(
hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
}
}
+
+ // the term for hash relation
+ private var relationTerm: String = _
+
+ override def upstream(): RDD[InternalRow] = {
+ streamedPlan.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
+ // create a name for HashRelation
+ val broadcastRelation = Await.result(broadcastFuture, timeout)
+ val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
+ relationTerm = ctx.freshName("relation")
+ // TODO: create specialized HashRelation for single join key
+ val clsName = classOf[UnsafeHashedRelation].getName
+ ctx.addMutableState(clsName, relationTerm,
+ s"""
+ | $relationTerm = ($clsName) $broadcast.value();
+ | incPeakExecutionMemory($relationTerm.getUnsafeSize());
+ """.stripMargin)
+
+ s"""
+ | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
+ """.stripMargin
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ // generate the key as UnsafeRow
+ ctx.currentVars = input
+ val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
+ val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+ val keyTerm = keyVal.value
+ val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"
+
+ // find the matches from HashedRelation
+ val matches = ctx.freshName("matches")
+ val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val i = ctx.freshName("i")
+ val size = ctx.freshName("size")
+ val row = ctx.freshName("row")
+
+ // create variables for output
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
+ BoundReference(i, a.dataType, a.nullable).gen(ctx)
+ }
+ val resultVars = buildSide match {
+ case BuildLeft => buildColumns ++ input
+ case BuildRight => input ++ buildColumns
+ }
+
+ val outputCode = if (condition.isDefined) {
+ // filter the output via condition
+ ctx.currentVars = resultVars
+ val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+ s"""
+ | ${ev.code}
+ | if (!${ev.isNull} && ${ev.value}) {
+ | ${consume(ctx, resultVars)}
+ | }
+ """.stripMargin
+ } else {
+ consume(ctx, resultVars)
+ }
+
+ s"""
+ | // generate join key
+ | ${keyVal.code}
+ | // find matches from HashRelation
+ | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
+ | if ($matches != null) {
+ | int $size = $matches.size();
+ | for (int $i = 0; $i < $size; $i++) {
+ | UnsafeRow $row = (UnsafeRow) $matches.apply($i);
+ | ${buildColumns.map(_.code).mkString("\n")}
+ | $outputCode
+ | }
+ | }
+ """.stripMargin
+ }
}
object BroadcastHashJoin {
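At runtime the generated code behaves like a standard broadcast hash join inner loop: compute the join key for each streamed row, look up the matching build rows in the broadcast hash relation, and emit combined rows that pass the optional condition. A plain-Scala sketch of that loop (the `Map` of rows stands in for the broadcast `UnsafeHashedRelation`; row and key types are illustrative):

```scala
// Plain-Scala sketch of what the generated join loop does at runtime.
object BroadcastHashJoinSketch {
  type Row = Map[String, Any]

  def innerJoin(
      streamed: Iterator[Row],
      buildSide: Map[Any, Seq[Row]],                 // stand-in for the broadcast hash relation
      streamedKey: Row => Any,
      condition: Row => Boolean = _ => true): Iterator[Row] =
    streamed.flatMap { s =>
      buildSide.getOrElse(streamedKey(s), Seq.empty).iterator
        .map(b => s ++ b)                            // append build-side columns to the streamed row
        .filter(condition)
    }
}
```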
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index db8edd169dcfa..f48fc3b84864d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -76,7 +76,7 @@ case class BroadcastHashOuterJoin(
// broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- future {
+ Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sparkContext, executionId) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
new file mode 100644
index 0000000000000..256f4228ae99e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+
+
+/**
+ * Take the first `limit` elements and collect them to a single partition.
+ *
+ * This operator will be used when a logical `Limit` operation is the final operator in a
+ * logical plan, which happens when the user is collecting results back to the driver.
+ */
+case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = SinglePartition
+ override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
+ private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+ protected override def doExecute(): RDD[InternalRow] = {
+ val shuffled = new ShuffledRowRDD(
+ Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer))
+ shuffled.mapPartitionsInternal(_.take(limit))
+ }
+}
+
+/**
+ * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]].
+ */
+trait BaseLimit extends UnaryNode {
+ val limit: Int
+ override def output: Seq[Attribute] = child.output
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
+ iter.take(limit)
+ }
+}
+
+/**
+ * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
+ */
+case class LocalLimit(limit: Int, child: SparkPlan) extends BaseLimit {
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+/**
+ * Take the first `limit` elements of the child's single output partition.
+ */
+case class GlobalLimit(limit: Int, child: SparkPlan) extends BaseLimit {
+ override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+}
+
+/**
+ * Take the first limit elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a Limit operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering,
+ * so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrderedAndProject(
+ limit: Int,
+ sortOrder: Seq[SortOrder],
+ projectList: Option[Seq[NamedExpression]],
+ child: SparkPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = {
+ val projectOutput = projectList.map(_.map(_.toAttribute))
+ projectOutput.getOrElse(child.output)
+ }
+
+ override def outputPartitioning: Partitioning = SinglePartition
+
+ // We need to use an interpreted ordering here because generated orderings cannot be serialized
+ // and this ordering needs to be created on the driver in order to be passed into Spark core code.
+ private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
+
+ private def collectData(): Array[InternalRow] = {
+ val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ if (projectList.isDefined) {
+ val proj = UnsafeProjection.create(projectList.get, child.output)
+ data.map(r => proj(r).copy())
+ } else {
+ data
+ }
+ }
+
+ override def executeCollect(): Array[InternalRow] = {
+ collectData()
+ }
+
+ // TODO: Terminal split should be implemented differently from non-terminal split.
+ // TODO: Pick num splits based on |limit|.
+ protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
+
+ override def simpleString: String = {
+ val orderByString = sortOrder.mkString("[", ",", "]")
+ val outputString = output.mkString("[", ",", "]")
+
+ s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
+ }
+}
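The split into `LocalLimit`, `GlobalLimit`, and `CollectLimit` follows the usual two-phase shape: truncate each partition locally, then enforce the global count in a single partition (or via `executeTake` on the driver). A plain-Scala sketch of the two phases over in-memory "partitions":

```scala
// Two-phase limit over in-memory "partitions": take at most `limit` per partition,
// then take `limit` overall after bringing the survivors together.
object LimitSketch {
  def localLimit(partitions: Seq[Seq[Int]], limit: Int): Seq[Seq[Int]] =
    partitions.map(_.take(limit))

  def globalLimit(partitions: Seq[Seq[Int]], limit: Int): Seq[Int] =
    partitions.flatten.take(limit)

  def main(args: Array[String]): Unit = {
    val parts = Seq(Seq(1, 2, 3, 4), Seq(5, 6), Seq(7, 8, 9))
    val result = globalLimit(localLimit(parts, 2), 5)
    assert(result == Seq(1, 2, 5, 6, 7))
  }
}
```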
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index a0dfe996ccd55..8726e4878106d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution.local
-import scala.util.control.NonFatal
-
import org.apache.spark.Logging
import org.apache.spark.sql.{Row, SQLConf}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -96,33 +94,13 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin
inputSchema: Seq[Attribute]): () => MutableProjection = {
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema")
- try {
- GenerateMutableProjection.generate(expressions, inputSchema)
- } catch {
- case NonFatal(e) =>
- if (isTesting) {
- throw e
- } else {
- log.error("Failed to generate mutable projection, fallback to interpreted", e)
- () => new InterpretedMutableProjection(expressions, inputSchema)
- }
- }
+ GenerateMutableProjection.generate(expressions, inputSchema)
}
protected def newPredicate(
expression: Expression,
inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
- try {
- GeneratePredicate.generate(expression, inputSchema)
- } catch {
- case NonFatal(e) =>
- if (isTesting) {
- throw e
- } else {
- log.error("Failed to generate predicate, fallback to interpreted", e)
- InterpretedPredicate.create(expression, inputSchema)
- }
- }
+ GeneratePredicate.generate(expression, inputSchema)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 950dc7816241f..6b43d273fefde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.metric
import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
+import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.util.Utils
/**
@@ -27,9 +28,16 @@ import org.apache.spark.util.Utils
* An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
*/
private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
- name: String, val param: SQLMetricParam[R, T])
+ name: String,
+ val param: SQLMetricParam[R, T])
extends Accumulable[R, T](param.zero, param, Some(name), internal = true) {
+ // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
+ override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ new AccumulableInfo(id, Some(name), update, value, isInternal, countFailedValues,
+ Some(SQLMetrics.ACCUM_IDENTIFIER))
+ }
+
def reset(): Unit = {
this.value = param.zero
}
@@ -73,6 +81,14 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr
// Although there is a boxing here, it's fine because it's only called in SQLListener
override def value: Long = _value
+
+ // Needed for SQLListenerSuite
+ override def equals(other: Any): Boolean = {
+ other match {
+ case o: LongSQLMetricValue => value == o.value
+ case _ => false
+ }
+ }
}
/**
@@ -126,6 +142,9 @@ private object StaticsLongSQLMetricParam extends LongSQLMetricParam(
private[sql] object SQLMetrics {
+ // Identifier for distinguishing SQL metrics from other accumulators
+ private[sql] val ACCUM_IDENTIFIER = "sql"
+
private def createLongMetric(
sc: SparkContext,
name: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 2acca1743cbb9..582dda8603f4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -53,14 +53,14 @@ trait ObjectOperator extends SparkPlan {
*/
case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
- input: Expression,
+ deserializer: Expression,
serializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryNode with ObjectOperator {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
- val getObject = generateToObject(input, child.output)
+ val getObject = generateToObject(deserializer, child.output)
val outputObject = generateToRow(serializer)
func(iter.map(getObject)).map(outputObject)
}
@@ -72,7 +72,7 @@ case class MapPartitions(
*/
case class AppendColumns(
func: Any => Any,
- input: Expression,
+ deserializer: Expression,
serializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryNode with ObjectOperator {
@@ -82,7 +82,7 @@ case class AppendColumns(
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
- val getObject = generateToObject(input, child.output)
+ val getObject = generateToObject(deserializer, child.output)
val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
val outputObject = generateToRow(serializer)
@@ -103,10 +103,11 @@ case class AppendColumns(
*/
case class MapGroups(
func: (Any, Iterator[Any]) => TraversableOnce[Any],
- keyObject: Expression,
- input: Expression,
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
child: SparkPlan) extends UnaryNode with ObjectOperator {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
@@ -121,8 +122,8 @@ case class MapGroups(
child.execute().mapPartitionsInternal { iter =>
val grouped = GroupedIterator(iter, groupingAttributes, child.output)
- val getKey = generateToObject(keyObject, groupingAttributes)
- val getValue = generateToObject(input, child.output)
+ val getKey = generateToObject(keyDeserializer, groupingAttributes)
+ val getValue = generateToObject(valueDeserializer, dataAttributes)
val outputObject = generateToRow(serializer)
grouped.flatMap { case (key, rowIter) =>
@@ -142,12 +143,14 @@ case class MapGroups(
*/
case class CoGroup(
func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
- keyObject: Expression,
- leftObject: Expression,
- rightObject: Expression,
+ keyDeserializer: Expression,
+ leftDeserializer: Expression,
+ rightDeserializer: Expression,
serializer: Seq[NamedExpression],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
+ leftAttr: Seq[Attribute],
+ rightAttr: Seq[Attribute],
left: SparkPlan,
right: SparkPlan) extends BinaryNode with ObjectOperator {
@@ -164,9 +167,9 @@ case class CoGroup(
val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
- val getKey = generateToObject(keyObject, leftGroup)
- val getLeft = generateToObject(leftObject, left.output)
- val getRight = generateToObject(rightObject, right.output)
+ val getKey = generateToObject(keyDeserializer, leftGroup)
+ val getLeft = generateToObject(leftDeserializer, leftAttr)
+ val getRight = generateToObject(rightDeserializer, rightAttr)
val outputObject = generateToRow(serializer)
new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
index e3a016e18db87..bf62bb05c3d93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
@@ -143,7 +143,7 @@ object EvaluatePython {
values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType)
i += 1
}
- new GenericInternalRowWithSchema(values, struct)
+ new GenericRowWithSchema(values, struct)
case (a: ArrayData, array: ArrayType) =>
val values = new java.util.ArrayList[Any](a.numElements())
@@ -199,10 +199,7 @@ object EvaluatePython {
case (c: Long, TimestampType) => c
- case (c: String, StringType) => UTF8String.fromString(c)
- case (c, StringType) =>
- // If we get here, c is not a string. Call toString on it.
- UTF8String.fromString(c.toString)
+ case (c, StringType) => UTF8String.fromString(c.toString)
case (c: String, BinaryType) => c.getBytes("utf-8")
case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
@@ -263,11 +260,11 @@ object EvaluatePython {
}
/**
- * Pickler for InternalRow
+ * Pickler for external row.
*/
private class RowPickler extends IObjectPickler {
- private val cls = classOf[GenericInternalRowWithSchema]
+ private val cls = classOf[GenericRowWithSchema]
// register this to Pickler and Unpickler
def register(): Unit = {
@@ -282,7 +279,7 @@ object EvaluatePython {
} else {
// it will be memorized by Pickler to save some bytes
pickler.save(this)
- val row = obj.asInstanceOf[GenericInternalRowWithSchema]
+ val row = obj.asInstanceOf[GenericRowWithSchema]
// schema should always be same object for memoization
pickler.save(row.schema)
out.write(Opcodes.TUPLE1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala
new file mode 100644
index 0000000000000..1f25eb8fc5223
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Used to pass a batch of data through a streaming query execution along with an indication
+ * of progress in the stream.
+ */
+class Batch(val end: Offset, val data: DataFrame)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala
new file mode 100644
index 0000000000000..d2cb20ef8b819
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import scala.util.Try
+
+/**
+ * An ordered collection of offsets, used to track the progress of processing data from one or more
+ * [[Source]]s that are present in a streaming query. This is similar to a simplified,
+ * single-instance vector clock that must progress linearly forward.
+ */
+case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset {
+ /**
+ * Returns a negative integer, zero, or a positive integer as this object is less than, equal to,
+ * or greater than the specified object.
+ */
+ override def compareTo(other: Offset): Int = other match {
+ case otherComposite: CompositeOffset if otherComposite.offsets.size == offsets.size =>
+ val comparisons = offsets.zip(otherComposite.offsets).map {
+ case (Some(a), Some(b)) => a compareTo b
+ case (None, None) => 0
+ case (None, _) => -1
+ case (_, None) => 1
+ }
+ val nonZeroSigns = comparisons.map(sign).filter(_ != 0).toSet
+ nonZeroSigns.size match {
+ case 0 => 0 // if both empty or only 0s
+ case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s)
+ case _ => // there are both 1s and -1s
+ throw new IllegalArgumentException(
+ s"Invalid comparison between non-linear histories: $this <=> $other")
+ }
+ case _ =>
+ throw new IllegalArgumentException(s"Cannot compare $this <=> $other")
+ }
+
+ private def sign(num: Int): Int = num match {
+ case i if i < 0 => -1
+ case i if i == 0 => 0
+ case i if i > 0 => 1
+ }
+}
+
+object CompositeOffset {
+ /**
+ * Returns a [[CompositeOffset]] with a variable sequence of offsets.
+ * `nulls` in the sequence are converted to `None`s.
+ */
+ def fill(offsets: Offset*): CompositeOffset = {
+ CompositeOffset(offsets.map(Option(_)))
+ }
+}
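Comparison is component-wise: the composite is ordered only when every pair of component offsets is equal or moves in the same direction, which is what makes it behave like a simplified vector clock. A hedged, REPL-style sketch using the `LongOffset` defined in the next file:

```scala
// REPL-style sketch of CompositeOffset comparisons (assumes the streaming package classes are available).
val a = CompositeOffset.fill(LongOffset(1), LongOffset(5))
val b = CompositeOffset.fill(LongOffset(2), LongOffset(5))
assert(a < b)                 // (1 < 2, 5 == 5): every component is equal or moves forward

// fill converts nulls to None for sources with no known offset yet.
assert(CompositeOffset.fill(LongOffset(1), null).offsets == Seq(Some(LongOffset(1)), None))

val c = CompositeOffset.fill(LongOffset(0), LongOffset(9))
// a compareTo c mixes forward and backward components, so it throws:
// IllegalArgumentException("Invalid comparison between non-linear histories: ...")
```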
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
new file mode 100644
index 0000000000000..008195af38b75
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+/**
+ * A simple offset for sources that produce a single linear stream of data.
+ */
+case class LongOffset(offset: Long) extends Offset {
+
+ override def compareTo(other: Offset): Int = other match {
+ case l: LongOffset => offset.compareTo(l.offset)
+ case _ =>
+ throw new IllegalArgumentException(s"Invalid comparison of $getClass with ${other.getClass}")
+ }
+
+ def +(increment: Long): LongOffset = new LongOffset(offset + increment)
+ def -(decrement: Long): LongOffset = new LongOffset(offset - decrement)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala
new file mode 100644
index 0000000000000..0f5d6445b1e2b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+/**
+ * An offset is a monotonically increasing metric used to track progress in the computation of a
+ * stream. An [[Offset]] must be comparable, and the result of `compareTo` must be consistent
+ * with `equals` and `hashCode`.
+ */
+trait Offset extends Serializable {
+
+ /**
+ * Returns a negative integer, zero, or a positive integer as this object is less than, equal to,
+ * or greater than the specified object.
+ */
+ def compareTo(other: Offset): Int
+
+ def >(other: Offset): Boolean = compareTo(other) > 0
+ def <(other: Offset): Boolean = compareTo(other) < 0
+ def <=(other: Offset): Boolean = compareTo(other) <= 0
+ def >=(other: Offset): Boolean = compareTo(other) >= 0
+}
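The convenience operators all delegate to `compareTo`, so a concrete offset only has to implement that one method. For example, with the `LongOffset` from the previous file:

```scala
// REPL-style sketch: ordering operators on Offset are derived from compareTo.
val start = LongOffset(10)
val next = start + 1                  // LongOffset(11)

assert(next > start)
assert(start <= start)
assert((next - 2) < start)            // LongOffset(9) < LongOffset(10)

// Comparing offsets of different concrete types is rejected:
// LongOffset(0) compareTo CompositeOffset(Nil)  => IllegalArgumentException
```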
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
new file mode 100644
index 0000000000000..1bd71b6b02ea9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+/**
+ * An interface for systems that can collect the results of a streaming query.
+ *
+ * When new data is produced by a query, a [[Sink]] must be able to transactionally collect the
+ * data and update the [[Offset]]. In the case of a failure, the sink will be recreated
+ * and must be able to return the [[Offset]] for all of the data that is made durable.
+ * This contract allows Spark to process data with exactly-once semantics, even in the case
+ * of failures that require the computation to be restarted.
+ */
+trait Sink {
+ /**
+ * Returns the [[Offset]] for all data that is currently present in the sink, if any. This
+ * function will be called by Spark when restarting execution in order to determine the point
+ * in the input stream from which computation should be resumed.
+ */
+ def currentOffset: Option[Offset]
+
+ /**
+ * Accepts a new batch of data as well as an [[Offset]] that denotes how far into the input
+ * data computation has progressed. When computation restarts after a failure, it is important
+ * that a [[Sink]] returns the same [[Offset]] as the most recent batch of data that
+ * has been persisted durably. Note that this does not necessarily have to be the
+ * [[Offset]] for the most recent batch of data that was given to the sink. For example,
+ * it is valid to buffer data before persisting, as long as the [[Offset]] is stored
+ * transactionally as data is eventually persisted.
+ */
+ def addBatch(batch: Batch): Unit
+}
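Restated, a sink must expose the last offset it has made durable and apply each batch atomically together with that offset; the `MemorySink` later in this patch follows exactly this shape. A minimal hedged sketch of another conforming (in-memory, non-durable) implementation, with an illustrative class name:

```scala
// Minimal sketch of a Sink: keep each committed Batch, and report the offset of the
// last one as the resume point. Everything here is illustrative, not part of the patch.
import scala.collection.mutable.ArrayBuffer

class BufferingSink extends Sink {
  private val committed = new ArrayBuffer[Batch]()

  // Offset of the most recent batch this sink has "made durable", if any.
  override def currentOffset: Option[Offset] = synchronized {
    committed.lastOption.map(_.end)
  }

  // Store the batch and its offset in one synchronized step, so a restart observes
  // either both or neither.
  override def addBatch(batch: Batch): Unit = synchronized {
    committed.append(batch)
  }
}
```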
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
new file mode 100644
index 0000000000000..25922979ac83e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A source of continually arriving data for a streaming query. A [[Source]] must have a
+ * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark
+ * will regularly query each [[Source]] to see if any more data is available.
+ */
+trait Source {
+
+ /** Returns the schema of the data from this source */
+ def schema: StructType
+
+ /**
+ * Returns the next batch of data that is available after `start`, if any is available.
+ */
+ def getNextBatch(start: Option[Offset]): Option[Batch]
+}
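A source therefore needs only a schema and an incremental `getNextBatch`; the in-memory `MemoryStream` later in this patch is the reference example. A pure-logic sketch of the `getNextBatch` contract, with no Spark types (offsets are plain `Long` indices and all names are illustrative):

```scala
// Pure-logic sketch of the getNextBatch contract: return only data that arrived
// strictly after `start`, tagged with the offset reached.
object SourceSketch {
  case class SimpleBatch(end: Long, rows: Seq[String])

  class BufferedSource {
    private var buffer = Vector.empty[String]
    private var current = -1L

    def add(row: String): Long = synchronized { buffer :+= row; current += 1; current }

    def getNextBatch(start: Option[Long]): Option[SimpleBatch] = synchronized {
      val from = start.getOrElse(-1L) + 1
      val newRows = buffer.drop(from.toInt)
      if (newRows.nonEmpty) Some(SimpleBatch(current, newRows)) else None
    }
  }
}
```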
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
new file mode 100644
index 0000000000000..ebebb829710b2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.lang.Thread.UncaughtExceptionHandler
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.{ContinuousQuery, DataFrame, SQLContext}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.QueryExecution
+
+/**
+ * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread.
+ * Unlike a standard query, a streaming query executes repeatedly each time new data arrives at any
+ * [[Source]] present in the query plan. Whenever new data arrives, a [[QueryExecution]] is created
+ * and the results are committed transactionally to the given [[Sink]].
+ */
+class StreamExecution(
+ sqlContext: SQLContext,
+ private[sql] val logicalPlan: LogicalPlan,
+ val sink: Sink) extends ContinuousQuery with Logging {
+
+ /** A monitor used to wait/notify when batches complete. */
+ private val awaitBatchLock = new Object
+
+ @volatile
+ private var batchRun = false
+
+ /** Minimum amount of time in between the start of each batch. */
+ private val minBatchTime = 10
+
+ /** Tracks how much data we have processed from each input source. */
+ private[sql] val streamProgress = new StreamProgress
+
+ /** All stream sources present in the query plan. */
+ private val sources =
+ logicalPlan.collect { case s: StreamingRelation => s.source }
+
+ // Start the execution at the current offsets stored in the sink (i.e. avoid reprocessing data
+ // that we have already processed).
+ {
+ sink.currentOffset match {
+ case Some(c: CompositeOffset) =>
+ val storedProgress = c.offsets
+ val sources = logicalPlan collect {
+ case StreamingRelation(source, _) => source
+ }
+
+ assert(sources.size == storedProgress.size)
+ sources.zip(storedProgress).foreach { case (source, offset) =>
+ offset.foreach(streamProgress.update(source, _))
+ }
+ case None => // We are starting this stream for the first time.
+ case _ => throw new IllegalArgumentException("Expected composite offset from sink")
+ }
+ }
+
+ logInfo(s"Stream running at $streamProgress")
+
+ /** When false, signals to the microBatchThread that it should stop running. */
+ @volatile private var shouldRun = true
+
+ /** The thread that runs the micro-batches of this stream. */
+ private[sql] val microBatchThread = new Thread("stream execution thread") {
+ override def run(): Unit = {
+ SQLContext.setActive(sqlContext)
+ while (shouldRun) {
+ attemptBatch()
+ Thread.sleep(minBatchTime) // TODO: Could be tighter
+ }
+ }
+ }
+ microBatchThread.setDaemon(true)
+ microBatchThread.setUncaughtExceptionHandler(
+ new UncaughtExceptionHandler {
+ override def uncaughtException(t: Thread, e: Throwable): Unit = {
+ streamDeathCause = e
+ }
+ })
+ microBatchThread.start()
+
+ @volatile
+ private[sql] var lastExecution: QueryExecution = null
+ @volatile
+ private[sql] var streamDeathCause: Throwable = null
+
+ /**
+ * Checks to see if any new data is present in any of the sources. When new data is available,
+ * a batch is executed and passed to the sink, updating the currentOffsets.
+ */
+ private def attemptBatch(): Unit = {
+ val startTime = System.nanoTime()
+
+ // A list of offsets that need to be updated if this batch is successful.
+ // Populated while walking the tree.
+ val newOffsets = new ArrayBuffer[(Source, Offset)]
+ // A list of attributes that will need to be updated.
+ var replacements = new ArrayBuffer[(Attribute, Attribute)]
+ // Replace sources in the logical plan with data that has arrived since the last batch.
+ val withNewSources = logicalPlan transform {
+ case StreamingRelation(source, output) =>
+ val prevOffset = streamProgress.get(source)
+ val newBatch = source.getNextBatch(prevOffset)
+
+ newBatch.map { batch =>
+ newOffsets += ((source, batch.end))
+ val newPlan = batch.data.logicalPlan
+
+ assert(output.size == newPlan.output.size)
+ replacements ++= output.zip(newPlan.output)
+ newPlan
+ }.getOrElse {
+ LocalRelation(output)
+ }
+ }
+
+ // Rewire the plan to use the new attributes that were returned by the source.
+ val replacementMap = AttributeMap(replacements)
+ val newPlan = withNewSources transformAllExpressions {
+ case a: Attribute if replacementMap.contains(a) => replacementMap(a)
+ }
+
+ if (newOffsets.nonEmpty) {
+ val optimizerStart = System.nanoTime()
+
+ lastExecution = new QueryExecution(sqlContext, newPlan)
+ val executedPlan = lastExecution.executedPlan
+ val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000
+ logDebug(s"Optimized batch in ${optimizerTime}ms")
+
+ streamProgress.synchronized {
+ // Update the offsets and calculate a new composite offset
+ newOffsets.foreach(streamProgress.update)
+ val newStreamProgress = logicalPlan.collect {
+ case StreamingRelation(source, _) => streamProgress.get(source)
+ }
+ val batchOffset = CompositeOffset(newStreamProgress)
+
+ // Construct the batch and send it to the sink.
+ val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan))
+ sink.addBatch(nextBatch)
+ }
+
+ batchRun = true
+ awaitBatchLock.synchronized {
+ // Wake up any threads that are waiting for the stream to progress.
+ awaitBatchLock.notifyAll()
+ }
+
+ val batchTime = (System.nanoTime() - startTime).toDouble / 1000000
+ logInfo(s"Compete up to $newOffsets in ${batchTime}ms")
+ }
+
+ logDebug(s"Waiting for data, current: $streamProgress")
+ }
+
+ /**
+ * Signals to the thread executing micro-batches that it should stop running after the next
+ * batch. This method blocks until the thread stops running.
+ */
+ def stop(): Unit = {
+ shouldRun = false
+ if (microBatchThread.isAlive) { microBatchThread.join() }
+ }
+
+ /**
+ * Blocks the current thread until processing for data from the given `source` has reached at
+ * least the given `Offset`. This method is intended primarily for use when writing tests.
+ */
+ def awaitOffset(source: Source, newOffset: Offset): Unit = {
+ def notDone = streamProgress.synchronized {
+ !streamProgress.contains(source) || streamProgress(source) < newOffset
+ }
+
+ while (notDone) {
+ logInfo(s"Waiting until $newOffset at $source")
+ awaitBatchLock.synchronized { awaitBatchLock.wait(100) }
+ }
+ logDebug(s"Unblocked at $newOffset for $source")
+ }
+
+ override def toString: String =
+ s"""
+ |=== Streaming Query ===
+ |CurrentOffsets: $streamProgress
+ |Thread State: ${microBatchThread.getState}
+ |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""}
+ |
+ |$logicalPlan
+ """.stripMargin
+}
+
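The heart of `attemptBatch` is a poll-rewrite-commit loop: ask each source for data past its last known offset, substitute that data for the corresponding `StreamingRelation` (rewiring attributes through an `AttributeMap`), run the batch, and hand it to the sink along with the new composite offset. A stripped-down sketch of that loop, reusing the stand-in types from the `Source` sketch above (plan rewriting and offset bookkeeping are simplified away):

```scala
// Stripped-down micro-batch step: poll sources, commit new data, then advance offsets.
import scala.collection.mutable
import SourceSketch.{BufferedSource, SimpleBatch}

object MicroBatchSketch {
  def runOnce(
      sources: Seq[BufferedSource],
      progress: mutable.Map[BufferedSource, Long],
      commit: Seq[SimpleBatch] => Unit): Unit = {

    // Collect whatever has arrived since each source's last recorded offset.
    val newBatches = sources.flatMap { s =>
      s.getNextBatch(progress.get(s)).map(b => s -> b)
    }

    if (newBatches.nonEmpty) {
      commit(newBatches.map(_._2))                     // hand the batch to the sink
      newBatches.foreach { case (s, b) => progress(s) = b.end }
    }
  }
}
```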
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
new file mode 100644
index 0000000000000..0ded1d7152c19
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import scala.collection.mutable
+
+/**
+ * A helper class that looks like a Map[Source, Offset].
+ */
+class StreamProgress {
+ private val currentOffsets = new mutable.HashMap[Source, Offset]
+
+ private[streaming] def update(source: Source, newOffset: Offset): Unit = {
+ currentOffsets.get(source).foreach(old =>
+ assert(newOffset > old, s"Stream going backwards $newOffset -> $old"))
+ currentOffsets.put(source, newOffset)
+ }
+
+ private[streaming] def update(newOffset: (Source, Offset)): Unit =
+ update(newOffset._1, newOffset._2)
+
+ private[streaming] def apply(source: Source): Offset = currentOffsets(source)
+ private[streaming] def get(source: Source): Option[Offset] = currentOffsets.get(source)
+ private[streaming] def contains(source: Source): Boolean = currentOffsets.contains(source)
+
+ private[streaming] def ++(updates: Map[Source, Offset]): StreamProgress = {
+ val updated = new StreamProgress
+ currentOffsets.foreach(updated.update)
+ updates.foreach(updated.update)
+ updated
+ }
+
+ /**
+ * Used to create a new copy of this [[StreamProgress]]. While this class is currently mutable,
+ * it should be copied before being passed to user code.
+ */
+ private[streaming] def copy(): StreamProgress = {
+ val copied = new StreamProgress
+ currentOffsets.foreach(copied.update)
+ copied
+ }
+
+ override def toString: String =
+ currentOffsets.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}")
+
+ override def equals(other: Any): Boolean = other match {
+ case s: StreamProgress => currentOffsets == s.currentOffsets
+ case _ => false
+ }
+
+ override def hashCode: Int = currentOffsets.hashCode()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
new file mode 100644
index 0000000000000..e35c444348f48
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+
+object StreamingRelation {
+ def apply(source: Source): StreamingRelation =
+ StreamingRelation(source, source.schema.toAttributes)
+}
+
+/**
+ * Used to link a streaming [[Source]] of data into a
+ * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
+ */
+case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
+ override def toString: String = source.toString
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
new file mode 100644
index 0000000000000..e6a0842936ea2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
+import org.apache.spark.sql.types.StructType
+
+object MemoryStream {
+ protected val currentBlockId = new AtomicInteger(0)
+ protected val memoryStreamId = new AtomicInteger(0)
+
+ def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
+ new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+}
+
+/**
+ * A [[Source]] that produces values stored in memory as they are added by the user. This [[Source]]
+ * is primarily intended for use in unit tests as it can only replay data when the object is still
+ * available.
+ */
+case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
+ extends Source with Logging {
+ protected val encoder = encoderFor[A]
+ protected val logicalPlan = StreamingRelation(this)
+ protected val output = logicalPlan.output
+ protected val batches = new ArrayBuffer[Dataset[A]]
+ protected var currentOffset: LongOffset = new LongOffset(-1)
+
+ protected def blockManager = SparkEnv.get.blockManager
+
+ def schema: StructType = encoder.schema
+
+ def getCurrentOffset: Offset = currentOffset
+
+ def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
+ new Dataset(sqlContext, logicalPlan)
+ }
+
+ def toDF()(implicit sqlContext: SQLContext): DataFrame = {
+ new DataFrame(sqlContext, logicalPlan)
+ }
+
+ def addData(data: TraversableOnce[A]): Offset = {
+ import sqlContext.implicits._
+ this.synchronized {
+ currentOffset = currentOffset + 1
+ val ds = data.toVector.toDS()
+ logDebug(s"Adding ds: $ds")
+ batches.append(ds)
+ currentOffset
+ }
+ }
+
+ override def getNextBatch(start: Option[Offset]): Option[Batch] = synchronized {
+ val newBlocks =
+ batches.drop(
+ start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1)
+
+ if (newBlocks.nonEmpty) {
+ logDebug(s"Running [$start, $currentOffset] on blocks ${newBlocks.mkString(", ")}")
+ val df = newBlocks
+ .map(_.toDF())
+ .reduceOption(_ unionAll _)
+ .getOrElse(sqlContext.emptyDataFrame)
+
+ Some(new Batch(currentOffset, df))
+ } else {
+ None
+ }
+ }
+
+ override def toString: String = s"MemoryStream[${output.mkString(",")}]"
+}
+
+/**
+ * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
+ * tests and does not provide durability.
+ */
+class MemorySink(schema: StructType) extends Sink with Logging {
+  /** An ordered list of batches that have been written to this [[Sink]]. */
+ private var batches = new ArrayBuffer[Batch]()
+
+ /** Used to convert an [[InternalRow]] to an external [[Row]] for comparison in testing. */
+ private val externalRowConverter = RowEncoder(schema)
+
+ override def currentOffset: Option[Offset] = synchronized {
+ batches.lastOption.map(_.end)
+ }
+
+ override def addBatch(nextBatch: Batch): Unit = synchronized {
+ batches.append(nextBatch)
+ }
+
+ /** Returns all rows that are stored in this [[Sink]]. */
+ def allData: Seq[Row] = synchronized {
+ batches
+ .map(_.data)
+ .reduceOption(_ unionAll _)
+ .map(_.collect().toSeq)
+ .getOrElse(Seq.empty)
+ }
+
+ /**
+ * Atomically drops the most recent `num` batches and resets the [[StreamProgress]] to the
+ * corresponding point in the input. This function can be used when testing to simulate data
+ * that has been lost due to buffering.
+ */
+ def dropBatches(num: Int): Unit = synchronized {
+    batches = batches.dropRight(num)  // dropRight returns a new buffer; reassign so the drop takes effect
+ }
+
+ override def toString: String = synchronized {
+ batches.map(b => s"${b.end}: ${b.data.collect().mkString(" ")}").mkString("\n")
+ }
+}
+
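A rough usage sketch of the test helpers above (assumes a SQLContext named `sqlContext` and its implicits are available, e.g. from a test harness; not part of the patch itself):

    implicit val ctx = sqlContext
    import sqlContext.implicits._               // provides Encoder[Int] for MemoryStream[Int]

    val input = MemoryStream[Int]               // in-memory Source for tests
    input.addData(Seq(1, 2, 3))                 // returns the Offset covering this data
    val firstBatch = input.getNextBatch(None)   // Some(Batch) holding 1, 2, 3

    val sink = new MemorySink(input.schema)     // remembers every Batch it receives
    firstBatch.foreach(sink.addBatch)
    // sink.allData now returns Row(1), Row(2), Row(3)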
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 544606f1168b6..835e7ba6c5168 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -23,7 +23,7 @@ import org.apache.spark.{JobExecutionStatus, Logging, SparkConf}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricParam, SQLMetricValue}
+import org.apache.spark.sql.execution.metric._
import org.apache.spark.ui.SparkUI
@DeveloperApi
@@ -314,14 +314,17 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging {
}
+
+/**
+ * A [[SQLListener]] for rendering the SQL UI in the history server.
+ */
private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
extends SQLListener(conf) {
private var sqlTabAttached = false
- override def onExecutorMetricsUpdate(
- executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized {
- // Do nothing
+ override def onExecutorMetricsUpdate(u: SparkListenerExecutorMetricsUpdate): Unit = {
+ // Do nothing; these events are not logged
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
@@ -329,9 +332,15 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
taskEnd.taskInfo.taskId,
taskEnd.stageId,
taskEnd.stageAttemptId,
- taskEnd.taskInfo.accumulables.map { a =>
- val newValue = new LongSQLMetricValue(a.update.map(_.asInstanceOf[Long]).getOrElse(0L))
- a.copy(update = Some(newValue))
+ taskEnd.taskInfo.accumulables.flatMap { a =>
+ // Filter out accumulators that are not SQL metrics
+        // For now we assume all SQL metrics are Longs that have been JSON serialized as Strings
+ if (a.metadata.exists(_ == SQLMetrics.ACCUM_IDENTIFIER)) {
+ val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L))
+ Some(a.copy(update = Some(newValue)))
+ } else {
+ None
+ }
},
finishTask = true)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3a27466176a20..b970eee4e31a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -349,19 +349,51 @@ object functions extends LegacyFunctions {
}
/**
- * Aggregate function: returns the first value in a group.
- *
- * @group agg_funcs
- * @since 1.3.0
- */
- def first(e: Column): Column = withAggregateFunction { new First(e.expr) }
-
- /**
- * Aggregate function: returns the first value of a column in a group.
- *
- * @group agg_funcs
- * @since 1.3.0
- */
+ * Aggregate function: returns the first value in a group.
+ *
+ * The function by default returns the first values it sees. It will return the first non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
+ new First(e.expr, Literal(ignoreNulls))
+ }
+
+ /**
+ * Aggregate function: returns the first value of a column in a group.
+ *
+ * The function by default returns the first values it sees. It will return the first non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def first(columnName: String, ignoreNulls: Boolean): Column = {
+ first(Column(columnName), ignoreNulls)
+ }
+
+ /**
+ * Aggregate function: returns the first value in a group.
+ *
+ * The function by default returns the first values it sees. It will return the first non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 1.3.0
+ */
+ def first(e: Column): Column = first(e, ignoreNulls = false)
+
+ /**
+ * Aggregate function: returns the first value of a column in a group.
+ *
+ * The function by default returns the first values it sees. It will return the first non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 1.3.0
+ */
def first(columnName: String): Column = first(Column(columnName))
/**
@@ -381,20 +413,52 @@ object functions extends LegacyFunctions {
def kurtosis(columnName: String): Column = kurtosis(Column(columnName))
/**
- * Aggregate function: returns the last value in a group.
- *
- * @group agg_funcs
- * @since 1.3.0
- */
- def last(e: Column): Column = withAggregateFunction { new Last(e.expr) }
-
- /**
- * Aggregate function: returns the last value of the column in a group.
- *
- * @group agg_funcs
- * @since 1.3.0
- */
- def last(columnName: String): Column = last(Column(columnName))
+ * Aggregate function: returns the last value in a group.
+ *
+ * The function by default returns the last values it sees. It will return the last non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
+ new Last(e.expr, Literal(ignoreNulls))
+ }
+
+ /**
+ * Aggregate function: returns the last value of the column in a group.
+ *
+ * The function by default returns the last values it sees. It will return the last non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
+ def last(columnName: String, ignoreNulls: Boolean): Column = {
+ last(Column(columnName), ignoreNulls)
+ }
+
+ /**
+ * Aggregate function: returns the last value in a group.
+ *
+ * The function by default returns the last values it sees. It will return the last non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 1.3.0
+ */
+ def last(e: Column): Column = last(e, ignoreNulls = false)
+
+ /**
+ * Aggregate function: returns the last value of the column in a group.
+ *
+ * The function by default returns the last values it sees. It will return the last non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 1.3.0
+ */
+ def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false)
/**
* Aggregate function: returns the maximum value of the expression in a group.
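A brief, hypothetical example of the new overloads (the DataFrame `df` and its `key`/`value` columns are made up for illustration):

    import org.apache.spark.sql.functions._

    df.groupBy("key").agg(
      first("value"),                      // first value seen; may be null
      first("value", ignoreNulls = true),  // first non-null value, or null if all are null
      last("value", ignoreNulls = true))   // last non-null value, or null if all are null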
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 8911ad370aa7b..737be7dfd12f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -35,8 +35,10 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.execution.{FileRelation, RDDConversions}
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.collection.BitSet
/**
* ::DeveloperApi::
@@ -123,6 +125,26 @@ trait SchemaRelationProvider {
schema: StructType): BaseRelation
}
+/**
+ * Implemented by objects that can produce a streaming [[Source]] for a specific format or system.
+ */
+trait StreamSourceProvider {
+ def createSource(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: Option[StructType]): Source
+}
+
+/**
+ * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system.
+ */
+trait StreamSinkProvider {
+ def createSink(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ partitionColumns: Seq[String]): Sink
+}
+
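To make the contract concrete, a hedged sketch of a provider implementing both traits; `MyFormatSource` and `MyFormatSink` are hypothetical classes, since the traits only fix the factory signatures:

    class MyFormatProvider extends StreamSourceProvider with StreamSinkProvider {
      override def createSource(
          sqlContext: SQLContext,
          parameters: Map[String, String],
          schema: Option[StructType]): Source =
        new MyFormatSource(sqlContext, parameters("path"), schema)  // hypothetical

      override def createSink(
          sqlContext: SQLContext,
          parameters: Map[String, String],
          partitionColumns: Seq[String]): Sink =
        new MyFormatSink(sqlContext, parameters("path"))            // hypothetical
    }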
/**
* ::Experimental::
* Implemented by objects that produce relations for a specific kind of data source
@@ -701,6 +723,7 @@ abstract class HadoopFsRelation private[sql](
final private[sql] def buildInternalScan(
requiredColumns: Array[String],
filters: Array[Filter],
+ bucketSet: Option[BitSet],
inputPaths: Array[String],
broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = {
val inputStatuses = inputPaths.flatMap { input =>
@@ -722,9 +745,16 @@ abstract class HadoopFsRelation private[sql](
// id from file name. Then read these files into a RDD(use one-partition empty RDD for empty
// bucket), and coalesce it to one partition. Finally union all bucket RDDs to one result.
val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId =>
- groupedBucketFiles.get(bucketId).map { inputStatuses =>
- buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
- }.getOrElse(sqlContext.emptyResult)
+ // If the current bucketId is not set in the bucket bitSet, skip scanning it.
+      if (bucketSet.nonEmpty && !bucketSet.get.get(bucketId)) {
+ sqlContext.emptyResult
+ } else {
+ // When all the buckets need a scan (i.e., bucketSet is equal to None)
+        // or when the current bucket needs a scan (i.e., the bit of bucketId is set to true)
+ groupedBucketFiles.get(bucketId).map { inputStatuses =>
+ buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
+ }.getOrElse(sqlContext.emptyResult)
+ }
}
new UnionRDD(sqlContext.sparkContext, perBucketRows)
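A tiny illustration of the pruning rule described in the comments above (values made up; BitSet is org.apache.spark.util.collection.BitSet):

    val numBuckets = 8
    val bucketSet: Option[BitSet] = Some(new BitSet(numBuckets))
    bucketSet.foreach(_.set(3))                  // only bucket 3 needs to be scanned
    def needsScan(bucketId: Int): Boolean =
      bucketSet.forall(_.get(bucketId))          // None means "scan every bucket"
    // needsScan(3) == true, needsScan(5) == false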
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a6fb62c17d59b..1181244c8a4ed 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -850,9 +850,7 @@ public void testRuntimeNullabilityCheck() {
}
nullabilityCheck.expect(RuntimeException.class);
- nullabilityCheck.expectMessage(
- "Null value appeared in non-nullable field " +
- "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+ nullabilityCheck.expectMessage("Null value appeared in non-nullable field");
{
Row row = new GenericRow(new Object[] {
diff --git a/sql/core/src/test/resources/cars-malformed.csv b/sql/core/src/test/resources/cars-malformed.csv
new file mode 100644
index 0000000000000..cfa378c01f1d9
--- /dev/null
+++ b/sql/core/src/test/resources/cars-malformed.csv
@@ -0,0 +1,6 @@
+~ All the rows here are malformed having tokens more than the schema (header).
+year,make,model,comment,blank
+"2012","Tesla","S","No comment",,null,null
+
+1997,Ford,E350,"Go get one now they are going fast",,null,null
+2015,Chevy,,,,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index b1004bc5bc290..08fb7c9d84c0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -153,6 +153,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("agg without groups and functions") {
+ checkAnswer(
+ testData2.agg(lit(1)),
+ Row(1)
+ )
+ }
+
test("average") {
checkAnswer(
testData2.agg(avg('a), mean('a)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c17be8ace9287..a5e5f156423cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -42,6 +42,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil)
}
+ test("join - sorted columns not in join's outputSet") {
+ val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1)
+ val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2)
+ val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3)
+
+ checkAnswer(
+ df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2")
+ .orderBy('str_sort.asc, 'str.asc),
+ Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil)
+
+ checkAnswer(
+ df2.join(df3, $"df2.int" === $"df3.int", "inner")
+ .select($"df2.int", $"df3.int").orderBy($"df2.str".desc),
+ Row(5, 5) :: Row(1, 1) :: Nil)
+ }
+
test("join - join using multiple columns and specifying join type") {
val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 09bbe57a43ceb..c02133ffc8540 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -349,6 +349,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(3, "c") ::
Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
+
+ // check null equality
+ checkAnswer(
+ nullInts.intersect(nullInts),
+ Row(1) ::
+ Row(2) ::
+ Row(3) ::
+ Row(null) :: Nil)
+
+ // check if values are de-duplicated
+ checkAnswer(
+ allNulls.intersect(allNulls),
+ Row(null) :: Nil)
+
+ // check if values are de-duplicated
+ val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value")
+ checkAnswer(
+ df.intersect(df),
+ Row("id1", 1) ::
+ Row("id", 1) ::
+ Row("id1", 2) :: Nil)
}
test("intersect - nullability") {
@@ -933,6 +954,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(expected === actual)
}
+ test("Sorting columns are not in Filter and Project") {
+ checkAnswer(
+ upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc),
+ Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)
+ }
+
test("SPARK-9323: DataFrame.orderBy should support nested column name") {
val df = sqlContext.read.json(sparkContext.makeRDD(
"""{"a": {"b": 1}}""" :: Nil))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
index 09a56f6f3ae28..2bcbb1983f7ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
@@ -312,4 +312,46 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext {
Row("b", 3, null, null),
Row("b", 2, null, null)))
}
+
+ test("last/first with ignoreNulls") {
+ val nullStr: String = null
+ val df = Seq(
+ ("a", 0, nullStr),
+ ("a", 1, "x"),
+ ("a", 2, "y"),
+ ("a", 3, "z"),
+ ("a", 4, nullStr),
+ ("b", 1, nullStr),
+ ("b", 2, nullStr)).
+ toDF("key", "order", "value")
+ val window = Window.partitionBy($"key").orderBy($"order")
+ checkAnswer(
+ df.select(
+ $"key",
+ $"order",
+ first($"value").over(window),
+ first($"value", ignoreNulls = false).over(window),
+ first($"value", ignoreNulls = true).over(window),
+ last($"value").over(window),
+ last($"value", ignoreNulls = false).over(window),
+ last($"value", ignoreNulls = true).over(window)),
+ Seq(
+ Row("a", 0, null, null, null, null, null, null),
+ Row("a", 1, null, null, "x", "x", "x", "x"),
+ Row("a", 2, null, null, "x", "y", "y", "y"),
+ Row("a", 3, null, null, "x", "z", "z", "z"),
+ Row("a", 4, null, null, "x", null, null, "z"),
+ Row("b", 1, null, null, null, null, null, null),
+ Row("b", 2, null, null, null, null, null, null)))
+ }
+
+ test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
+ val src = Seq((0, 3, 5)).toDF("a", "b", "c")
+ .withColumn("Data", struct("a", "b"))
+ .drop("a")
+ .drop("b")
+ val winSpec = Window.partitionBy("Data.a", "Data.b").orderBy($"c".desc)
+ val df = src.select($"*", max("c").over(winSpec) as "max")
+ checkAnswer(df, Row(5, Row(0, 3), 5))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index f75d0961823c4..243d13b19d6cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
agged,
"1", "abc", "3", "xyz", "5", "hello")
}
+
+ test("Arrays and Lists") {
+ checkAnswer(Seq(Seq(1)).toDS(), Seq(1))
+ checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
+ checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
+ checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
+ checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
+ checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
+ checkAnswer(Seq(Seq(true)).toDS(), Seq(true))
+ checkAnswer(Seq(Seq("test")).toDS(), Seq("test"))
+ checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))
+
+ checkAnswer(Seq(Array(1)).toDS(), Array(1))
+ checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
+ checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
+ checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
+ checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
+ checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
+ checkAnswer(Seq(Array(true)).toDS(), Array(true))
+ checkAnswer(Seq(Array("test")).toDS(), Array("test"))
+ checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b69bb21db532b..f9ba60770022d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.postfixOps
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -45,13 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1, 1, 1)
}
-
test("SPARK-12404: Datatype Helper Serializablity") {
val ds = sparkContext.parallelize((
- new Timestamp(0),
- new Date(0),
- java.math.BigDecimal.valueOf(1),
- scala.math.BigDecimal(1)) :: Nil).toDS()
+ new Timestamp(0),
+ new Date(0),
+ java.math.BigDecimal.valueOf(1),
+ scala.math.BigDecimal(1)) :: Nil).toDS()
ds.collect()
}
@@ -523,7 +523,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("verify mismatching field names fail with a good error") {
val ds = Seq(ClassData("a", 1)).toDS()
val e = intercept[AnalysisException] {
- ds.as[ClassData2].collect()
+ ds.as[ClassData2]
}
assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
}
@@ -553,9 +553,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
buildDataset(Row(Row("hello", null))).collect()
}.getMessage
- assert(message.contains(
- "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
- ))
+ assert(message.contains("Null value appeared in non-nullable field"))
}
test("SPARK-12478: top level null field") {
@@ -567,6 +565,58 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
checkAnswer(ds1.toDF(), Row(Row(null)))
}
+
+ test("support inner class in Dataset") {
+ val outer = new OuterClass
+ OuterScopes.addOuterScope(outer)
+ val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS()
+ checkAnswer(ds.map(_.a), "1", "2")
+ }
+
+ test("grouping key and grouped value has field with same name") {
+ val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS()
+ val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups {
+ case (key, values) => key.a + values.map(_.b).sum
+ }
+
+ checkAnswer(agged, "a3")
+ }
+
+ test("cogroup's left and right side has field with same name") {
+ val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+ val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS()
+ val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) {
+ case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum)
+ }
+
+ checkAnswer(cogrouped, "a13", "b24")
+ }
+
+ test("give nice error message when the real number of fields doesn't match encoder schema") {
+ val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+ val message = intercept[AnalysisException] {
+ ds.as[(String, Int, Long)]
+ }.message
+ assert(message ==
+ "Try to map struct to Tuple3, " +
+ "but failed as the number of fields does not line up.\n" +
+ " - Input schema: struct\n" +
+ " - Target schema: struct<_1:string,_2:int,_3:bigint>")
+
+ val message2 = intercept[AnalysisException] {
+ ds.as[Tuple1[String]]
+ }.message
+ assert(message2 ==
+ "Try to map struct to Tuple1, " +
+ "but failed as the number of fields does not line up.\n" +
+ " - Input schema: struct\n" +
+ " - Target schema: struct<_1:string>")
+ }
+}
+
+class OuterClass extends Serializable {
+ case class InnerClass(a: String)
}
case class ClassData(a: String, b: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index ce12f788b786c..5401212428d6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -95,7 +95,13 @@ abstract class QueryTest extends PlanTest {
""".stripMargin, e)
}
- if (decoded != expectedAnswer.toSet) {
+ // Handle the case where the return type is an array
+ val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false)
+ def normalEquality = decoded == expectedAnswer.toSet
+ def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet
+ def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq)
+
+ if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
@@ -304,27 +310,7 @@ object QueryTest {
def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
- // We need to call prepareRow recursively to handle schemas with struct types.
- def prepareRow(row: Row): Row = {
- Row.fromSeq(row.toSeq.map {
- case null => null
- case d: java.math.BigDecimal => BigDecimal(d)
- // Convert array to Seq for easy equality check.
- case b: Array[_] => b.toSeq
- case r: Row => prepareRow(r)
- case o => o
- })
- }
- def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
- // Converts data to types that we can do equality comparison using Scala collections.
- // For BigDecimal type, the Scala type has a better definition of equality test (similar to
- // Java's java.math.BigDecimal.compareTo).
- // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
- // equality test.
- val converted: Seq[Row] = answer.map(prepareRow)
- if (!isSorted) converted.sortBy(_.toString()) else converted
- }
val sparkAnswer = try df.collect().toSeq catch {
case e: Exception =>
val errorMessage =
@@ -338,22 +324,56 @@ object QueryTest {
return Some(errorMessage)
}
- if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
- val errorMessage =
+ sameRows(expectedAnswer, sparkAnswer, isSorted).map { results =>
s"""
|Results do not match for query:
|${df.queryExecution}
|== Results ==
- |${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer).map(_.toString()),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
- """.stripMargin
- return Some(errorMessage)
+ |$results
+ """.stripMargin
}
+ }
+
+
+ def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
+ // Converts data to types that we can do equality comparison using Scala collections.
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+ // Java's java.math.BigDecimal.compareTo).
+ // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
+ // equality test.
+ val converted: Seq[Row] = answer.map(prepareRow)
+ if (!isSorted) converted.sortBy(_.toString()) else converted
+ }
+
+ // We need to call prepareRow recursively to handle schemas with struct types.
+ def prepareRow(row: Row): Row = {
+ Row.fromSeq(row.toSeq.map {
+ case null => null
+ case d: java.math.BigDecimal => BigDecimal(d)
+ // Convert array to Seq for easy equality check.
+ case b: Array[_] => b.toSeq
+ case r: Row => prepareRow(r)
+ case o => o
+ })
+ }
- return None
+ def sameRows(
+ expectedAnswer: Seq[Row],
+ sparkAnswer: Seq[Row],
+ isSorted: Boolean = false): Option[String] = {
+ if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) {
+ val errorMessage =
+ s"""
+ |== Results ==
+ |${sideBySide(
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer, isSorted).map(_.toString()),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+ None
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 989cb2942918e..8ef7b61314a56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -84,7 +84,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
"Extended Usage")
checkExistence(sql("describe functioN abcadf"), true,
- "Function: abcadf is not found.")
+ "Function: abcadf not found.")
}
test("SPARK-6743: no columns from cache") {
@@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("Common subexpression elimination") {
- // select from a table to prevent constant folding.
- val df = sql("SELECT a, b from testData2 limit 1")
- checkAnswer(df, Row(1, 1))
-
- checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
- checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
-
- // This does not work because the expressions get grouped like (a + a) + 1
- checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
- checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
-
- // Identity udf that tracks the number of times it is called.
- val countAcc = sparkContext.accumulator(0, "CallCount")
- sqlContext.udf.register("testUdf", (x: Int) => {
- countAcc.++=(1)
- x
- })
+ // TODO: support subexpression elimination in whole stage codegen
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ // select from a table to prevent constant folding.
+ val df = sql("SELECT a, b from testData2 limit 1")
+ checkAnswer(df, Row(1, 1))
+
+ checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
+ checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
+
+ // This does not work because the expressions get grouped like (a + a) + 1
+ checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
+ checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
+
+ // Identity udf that tracks the number of times it is called.
+ val countAcc = sparkContext.accumulator(0, "CallCount")
+ sqlContext.udf.register("testUdf", (x: Int) => {
+ countAcc.++=(1)
+ x
+ })
+
+ // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
+ // is correct.
+ def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
+ countAcc.setValue(0)
+ checkAnswer(df, expectedResult)
+ assert(countAcc.value == expectedCount)
+ }
- // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
- // is correct.
- def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
- countAcc.setValue(0)
- checkAnswer(df, expectedResult)
- assert(countAcc.value == expectedCount)
+ verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+ verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
+ verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
+
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+ val testUdf = functions.udf((x: Int) => {
+ countAcc.++=(1)
+ x
+ })
+ verifyCallCount(
+ df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+
+ // Would be nice if semantic equals for `+` understood commutative
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+ // Try disabling it via configuration.
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
}
-
- verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
- verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
- verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
- verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
- verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
-
- verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
-
- val testUdf = functions.udf((x: Int) => {
- countAcc.++=(1)
- x
- })
- verifyCallCount(
- df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
-
- // Would be nice if semantic equals for `+` understood commutative
- verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
-
- // Try disabling it via configuration.
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
- verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
- verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
}
test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {
@@ -2052,6 +2055,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-13056: Null in map value causes NPE") {
+ val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value")
+ withTempTable("maptest") {
+ df.registerTempTable("maptest")
+      // local optimization will bypass the codegen path, so we should keep the filter `key = 1`
+ checkAnswer(sql("SELECT value['abc'] FROM maptest where key = 1"), Row("somestring"))
+ checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null))
+ }
+ }
+
test("hash function") {
val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
withTempTable("tbl") {
@@ -2062,4 +2075,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
}
+
+ test("natural join") {
+ val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1")
+ val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2")
+ withTempTable("nt1", "nt2") {
+ df1.registerTempTable("nt1")
+ df2.registerTempTable("nt2")
+ checkAnswer(
+ sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""),
+ Row("one", 1, 1) :: Row("one", 1, 5) :: Nil)
+
+ checkAnswer(
+ sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"),
+ Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil)
+
+ checkAnswer(
+ sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"),
+ Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil)
+
+ checkAnswer(
+ sql("SELECT count(*) FROM nt1 natural full outer join nt2"),
+ Row(4) :: Nil)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
new file mode 100644
index 0000000000000..f45abbf2496a2
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.lang.Thread.UncaughtExceptionHandler
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.streaming._
+
+/**
+ * A framework for implementing tests for streaming queries and sources.
+ *
+ * A test consists of a set of steps (expressed as a `StreamAction`) that are executed in order,
+ * blocking as necessary to let the stream catch up. For example, the following adds some data to
+ * a stream, blocking until it can verify that the correct values are eventually produced.
+ *
+ * {{{
+ * val inputData = MemoryStream[Int]
+ *   val mapped = inputData.toDS().map(_ + 1)
+ *
+ *   testStream(mapped)(
+ *     AddData(inputData, 1, 2, 3),
+ *     CheckAnswer(2, 3, 4))
+ * }}}
+ *
+ * Note that while we do sleep to allow the other thread to progress without spinning,
+ * `StreamAction` checks should not depend on the amount of time spent sleeping. Instead they
+ * should check the actual progress of the stream before verifying the required test condition.
+ *
+ * Currently it is assumed that all streaming queries will eventually complete in 10 seconds to
+ * avoid hanging forever in the case of failures. However, individual suites can change this
+ * by overriding `streamingTimeout`.
+ */
+trait StreamTest extends QueryTest with Timeouts {
+
+ implicit class RichSource(s: Source) {
+ def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s))
+ }
+
+ /** How long to wait for an active stream to catch up when checking a result. */
+  val streamingTimeout = 10.seconds
+
+ /** A trait for actions that can be performed while testing a streaming DataFrame. */
+ trait StreamAction
+
+ /** A trait to mark actions that require the stream to be actively running. */
+ trait StreamMustBeRunning
+
+ /**
+   * Adds the given data to the stream. Subsequent check answers will block until this data has
+ * been processed.
+ */
+ object AddData {
+ def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] =
+ AddDataMemory(source, data)
+ }
+
+ /** A trait that can be extended when testing other sources. */
+ trait AddData extends StreamAction {
+ def source: Source
+
+ /**
+ * Called to trigger adding the data. Should return the offset that will denote when this
+ * new data has been processed.
+ */
+ def addData(): Offset
+ }
+
+ case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
+ override def toString: String = s"AddData to $source: ${data.mkString(",")}"
+
+ override def addData(): Offset = {
+ source.addData(data)
+ }
+ }
+
+ /**
+ * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`.
+   * This operation automatically blocks until all added data has been processed.
+ */
+ object CheckAnswer {
+ def apply[A : Encoder](data: A*): CheckAnswerRows = {
+ val encoder = encoderFor[A]
+ val toExternalRow = RowEncoder(encoder.schema)
+ CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))))
+ }
+
+ def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows)
+ }
+
+ case class CheckAnswerRows(expectedAnswer: Seq[Row])
+ extends StreamAction with StreamMustBeRunning {
+ override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}"
+ }
+
+ case class DropBatches(num: Int) extends StreamAction
+
+ /** Stops the stream. It must currently be running. */
+ case object StopStream extends StreamAction with StreamMustBeRunning
+
+ /** Starts the stream, resuming if data has already been processed. It must not be running. */
+ case object StartStream extends StreamAction
+
+ /** Signals that a failure is expected and should not kill the test. */
+ case object ExpectFailure extends StreamAction
+
+ /** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */
+ def testStream(stream: Dataset[_])(actions: StreamAction*): Unit =
+ testStream(stream.toDF())(actions: _*)
+
+ /**
+   * Executes the specified actions on the given streaming DataFrame and provides helpful
+ * error messages in the case of failures or incorrect answers.
+ *
+   * Note that if the stream is not explicitly started before an action that requires it to be
+ * running then it will be automatically started before performing any other actions.
+ */
+ def testStream(stream: DataFrame)(actions: StreamAction*): Unit = {
+ var pos = 0
+ var currentPlan: LogicalPlan = stream.logicalPlan
+ var currentStream: StreamExecution = null
+ val awaiting = new mutable.HashMap[Source, Offset]()
+ val sink = new MemorySink(stream.schema)
+
+ @volatile
+ var streamDeathCause: Throwable = null
+
+ // If the test doesn't manually start the stream, we do it automatically at the beginning.
+ val startedManually =
+ actions.takeWhile(!_.isInstanceOf[StreamMustBeRunning]).contains(StartStream)
+ val startedTest = if (startedManually) actions else StartStream +: actions
+
+ def testActions = actions.zipWithIndex.map {
+ case (a, i) =>
+ if ((pos == i && startedManually) || (pos == (i + 1) && !startedManually)) {
+ "=> " + a.toString
+ } else {
+ " " + a.toString
+ }
+ }.mkString("\n")
+
+ def currentOffsets =
+ if (currentStream != null) currentStream.streamProgress.toString else "not started"
+
+ def threadState =
+ if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead"
+ def testState =
+ s"""
+ |== Progress ==
+ |$testActions
+ |
+ |== Stream ==
+ |Stream state: $currentOffsets
+ |Thread state: $threadState
+ |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""}
+ |
+ |== Sink ==
+ |$sink
+ |
+ |== Plan ==
+ |${if (currentStream != null) currentStream.lastExecution else ""}
+ """
+
+ def checkState(check: Boolean, error: String) = if (!check) {
+ fail(
+ s"""
+ |Invalid State: $error
+ |$testState
+ """.stripMargin)
+ }
+
+ val testThread = Thread.currentThread()
+
+ try {
+ startedTest.foreach { action =>
+ action match {
+ case StartStream =>
+ checkState(currentStream == null, "stream already running")
+
+ currentStream = new StreamExecution(sqlContext, stream.logicalPlan, sink)
+ currentStream.microBatchThread.setUncaughtExceptionHandler(
+ new UncaughtExceptionHandler {
+ override def uncaughtException(t: Thread, e: Throwable): Unit = {
+ streamDeathCause = e
+ testThread.interrupt()
+ }
+ })
+
+ case StopStream =>
+ checkState(currentStream != null, "can not stop a stream that is not running")
+ currentStream.stop()
+ currentStream = null
+
+ case DropBatches(num) =>
+ checkState(currentStream == null, "dropping batches while running leads to corruption")
+ sink.dropBatches(num)
+
+ case ExpectFailure =>
+          try failAfter(streamingTimeout) {
+ while (streamDeathCause == null) {
+ Thread.sleep(100)
+ }
+ } catch {
+ case _: InterruptedException =>
+ case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
+ fail(
+ s"""
+ |Timed out while waiting for failure.
+ |$testState
+ """.stripMargin)
+ }
+
+ currentStream = null
+ streamDeathCause = null
+
+ case a: AddData =>
+ awaiting.put(a.source, a.addData())
+
+ case CheckAnswerRows(expectedAnswer) =>
+ checkState(currentStream != null, "stream not running")
+
+ // Block until all data added has been processed
+ awaiting.foreach { case (source, offset) =>
+              failAfter(streamingTimeout) {
+ currentStream.awaitOffset(source, offset)
+ }
+ }
+
+ val allData = try sink.allData catch {
+ case e: Exception =>
+ fail(
+ s"""
+ |Exception while getting data from sink $e
+ |$testState
+ """.stripMargin)
+ }
+
+ QueryTest.sameRows(expectedAnswer, allData).foreach {
+ error => fail(
+ s"""
+ |$error
+ |$testState
+ """.stripMargin)
+ }
+ }
+ pos += 1
+ }
+ } catch {
+ case _: InterruptedException if streamDeathCause != null =>
+ fail(
+ s"""
+ |Stream Thread Died
+ |$testState
+ """.stripMargin)
+ case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
+ fail(
+ s"""
+ |Timed out waiting for stream
+ |$testState
+ """.stripMargin)
+ } finally {
+ if (currentStream != null && currentStream.microBatchThread.isAlive) {
+ currentStream.stop()
+ }
+ }
+ }
+
+ /**
+ * Creates a stress test that randomly starts/stops/adds data/checks the result.
+ *
+ * @param ds a dataframe that executes + 1 on a stream of integers, returning the result.
+   * @param addData an add data action that adds the given numbers to the stream, encoding them
+ * as needed
+ */
+ def runStressTest(
+ ds: Dataset[Int],
+ addData: Seq[Int] => StreamAction,
+ iterations: Int = 100): Unit = {
+ implicit val intEncoder = ExpressionEncoder[Int]()
+ var dataPos = 0
+ var running = true
+ val actions = new ArrayBuffer[StreamAction]()
+
+ def addCheck() = { actions += CheckAnswer(1 to dataPos: _*) }
+
+ def addRandomData() = {
+ val numItems = Random.nextInt(10)
+ val data = dataPos until (dataPos + numItems)
+ dataPos += numItems
+ actions += addData(data)
+ }
+
+ (1 to iterations).foreach { i =>
+ val rand = Random.nextDouble()
+      if (!running) {
+ rand match {
+ case r if r < 0.7 => // AddData
+ addRandomData()
+
+ case _ => // StartStream
+ actions += StartStream
+ running = true
+ }
+ } else {
+ rand match {
+ case r if r < 0.1 =>
+ addCheck()
+
+ case r if r < 0.7 => // AddData
+ addRandomData()
+
+ case _ => // StartStream
+ actions += StopStream
+ running = false
+ }
+ }
+ }
+    if (!running) { actions += StartStream }
+ addCheck()
+ testStream(ds)(actions: _*)
+ }
+}
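As a rough sketch of how a suite might exercise this framework (the suite name is illustrative; the pattern mirrors the scaladoc example above and the MemoryStream/MemorySink helpers added earlier in this patch):

    class MemoryStreamRoundTripSuite extends StreamTest with SharedSQLContext {
      import testImplicits._

      test("map over a memory stream") {
        val input = MemoryStream[Int]
        val mapped = input.toDS().map(_ + 1)

        testStream(mapped)(
          AddData(input, 1, 2, 3),
          CheckAnswer(2, 3, 4),
          StopStream,
          StartStream,
          AddData(input, 4, 5),
          CheckAnswer(2, 3, 4, 5, 6))
      }
    }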
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index c4aad398bfa54..33d4976403d9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -18,7 +18,13 @@
package org.apache.spark.sql.execution
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.functions._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
+import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.Benchmark
/**
@@ -27,34 +33,176 @@ import org.apache.spark.util.Benchmark
* build/sbt "sql/test-only *BenchmarkWholeStageCodegen"
*/
class BenchmarkWholeStageCodegen extends SparkFunSuite {
- def testWholeStage(values: Int): Unit = {
- val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
- val sc = SparkContext.getOrCreate(conf)
- val sqlContext = SQLContext.getOrCreate(sc)
+ lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
+ .set("spark.sql.shuffle.partitions", "1")
+ lazy val sc = SparkContext.getOrCreate(conf)
+ lazy val sqlContext = SQLContext.getOrCreate(sc)
- val benchmark = new Benchmark("Single Int Column Scan", values)
+ def runBenchmark(name: String, values: Int)(f: => Unit): Unit = {
+ val benchmark = new Benchmark(name, values)
- benchmark.addCase("Without whole stage codegen") { iter =>
- sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
- sqlContext.range(values).filter("(id & 1) = 1").count()
+ Seq(false, true).foreach { enabled =>
+ benchmark.addCase(s"$name codegen=$enabled") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", enabled.toString)
+ f
+ }
}
- benchmark.addCase("With whole stage codegen") { iter =>
- sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
- sqlContext.range(values).filter("(id & 1) = 1").count()
- }
+ benchmark.run()
+ }
+  // These benchmarks are skipped in normal builds
+ ignore("range/filter/sum") {
+ val N = 500 << 20
+ runBenchmark("rang/filter/sum", N) {
+ sqlContext.range(N).filter("(id & 1) = 1").groupBy().sum().collect()
+ }
/*
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X
+ rang/filter/sum codegen=true 845 / 940 620.0 1.6 17.0X
+ */
+ }
+
+ ignore("stat functions") {
+ val N = 100 << 20
+
+ runBenchmark("stddev", N) {
+ sqlContext.range(N).groupBy().agg("id" -> "stddev").collect()
+ }
+
+ runBenchmark("kurtosis", N) {
+ sqlContext.range(N).groupBy().agg("id" -> "kurtosis").collect()
+ }
+
+
+ /**
+ Using ImperativeAggregate (as implemented in Spark 1.6):
+
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate
-------------------------------------------------------------------------------
- Without whole stage codegen 7775.53 26.97 1.00 X
- With whole stage codegen 342.15 612.94 22.73 X
+ stddev w/o codegen 2019.04 10.39 1.00 X
+ stddev w codegen 2097.29 10.00 0.96 X
+ kurtosis w/o codegen 2108.99 9.94 0.96 X
+ kurtosis w codegen 2090.69 10.03 0.97 X
+
+ Using DeclarativeAggregate:
+
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ stddev: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ stddev codegen=false 5630 / 5776 18.0 55.6 1.0X
+ stddev codegen=true 1259 / 1314 83.0 12.0 4.5X
+
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ kurtosis: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ kurtosis codegen=false 14847 / 15084 7.0 142.9 1.0X
+ kurtosis codegen=true 1652 / 2124 63.0 15.9 9.0X
+ */
+ }
+
+ ignore("aggregate with keys") {
+ val N = 20 << 20
+
+ runBenchmark("Aggregate w keys", N) {
+ sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
+ }
+
+ /*
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Aggregate w keys codegen=false 2402 / 2551 8.0 125.0 1.0X
+ Aggregate w keys codegen=true 1620 / 1670 12.0 83.3 1.5X
+ */
+ }
+
+ ignore("broadcast hash join") {
+ val N = 20 << 20
+ val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
+
+ runBenchmark("BroadcastHashJoin", N) {
+ sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
+ }
+
+ /*
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ BroadcastHashJoin codegen=false 4405 / 6147 4.0 250.0 1.0X
+ BroadcastHashJoin codegen=true 1857 / 1878 11.0 90.9 2.4X
*/
- benchmark.run()
}
- ignore("benchmark") {
- testWholeStage(1024 * 1024 * 200)
+ ignore("hash and BytesToBytesMap") {
+ val N = 50 << 20
+
+ val benchmark = new Benchmark("BytesToBytesMap", N)
+
+ benchmark.addCase("hash") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ val value = new UnsafeRow(2)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var s = 0
+ while (i < N) {
+ key.setInt(0, i % 1000)
+ val h = Murmur3_x86_32.hashUnsafeWords(
+ key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0)
+ s += h
+ i += 1
+ }
+ }
+
+ Seq("off", "on").foreach { heap =>
+ benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}")
+ .set("spark.memory.offHeap.size", "102400000"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20)
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ val value = new UnsafeRow(2)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var i = 0
+ while (i < N) {
+ key.setInt(0, i % 65536)
+ val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+ if (loc.isDefined) {
+ value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset,
+ loc.getValueLength)
+ value.setInt(0, value.getInt(0) + 1)
+ i += 1
+ } else {
+ loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
+ }
+ }
+ }
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ hash 628 / 661 83.0 12.0 1.0X
+ BytesToBytesMap (off Heap) 3292 / 3408 15.0 66.7 0.2X
+ BytesToBytesMap (on Heap) 3349 / 4267 15.0 66.7 0.2X
+ */
+ benchmark.run()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 8fca5e2167d04..a64ad4038c7c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -21,8 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, Row, SQLConf}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.functions._
@@ -182,6 +181,12 @@ class PlannerSuite extends SharedSQLContext {
}
}
+ test("terminal limits use CollectLimit") {
+ val query = testData.select('value).limit(2)
+ val planned = query.queryExecution.sparkPlan
+ assert(planned.isInstanceOf[CollectLimit])
+ }
+
test("PartitioningCollection") {
withTempTable("normal", "small", "tiny") {
testData.registerTempTable("normal")
@@ -201,7 +206,7 @@ class PlannerSuite extends SharedSQLContext {
).queryExecution.executedPlan.collect {
case exchange: Exchange => exchange
}.length
- assert(numExchanges === 3)
+ assert(numExchanges === 5)
}
{
@@ -216,13 +221,25 @@ class PlannerSuite extends SharedSQLContext {
).queryExecution.executedPlan.collect {
case exchange: Exchange => exchange
}.length
- assert(numExchanges === 3)
+ assert(numExchanges === 5)
}
}
}
}
+ test("collapse adjacent repartitions") {
+ val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5)
+ def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length
+ assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3)
+ assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1)
+ doubleRepartitioned.queryExecution.optimizedPlan match {
+ case r: Repartition =>
+ assert(r.numPartitions === 5)
+ assert(r.shuffle === false)
+ }
+ }
+
// --- Unit tests of EnsureRequirements ---------------------------------------------------------
// When it comes to testing whether EnsureRequirements properly ensures distribution requirements,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index 6259453da26a1..cb6d68dc3ac46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -56,8 +56,8 @@ class SortSuite extends SparkPlanTest with SharedSQLContext {
test("sort followed by limit") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
- (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)),
- (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, global = true, child)),
+ (child: SparkPlan) => GlobalLimit(10, Sort('a.asc :: Nil, global = true, child = child)),
+ (child: SparkPlan) => GlobalLimit(10, ReferenceSort('a.asc :: Nil, global = true, child)),
sortAnswers = false
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 300788c88ab2f..9350205d791d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.functions.{avg, col, max}
+import org.apache.spark.sql.execution.joins.BroadcastHashJoin
+import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
@@ -47,4 +49,24 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(9, 4.5)))
}
+
+ test("Aggregate with grouping keys should be included in WholeStageCodegen") {
+ val df = sqlContext.range(3).groupBy("id").count().orderBy("id")
+ val plan = df.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+ assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
+ }
+
+ test("BroadcastHashJoin should be included in WholeStageCodegen") {
+ val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
+ val schema = new StructType().add("k", IntegerType).add("v", StringType)
+ val smallDF = sqlContext.createDataFrame(rdd, schema)
+ val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
+ assert(df.queryExecution.executedPlan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
+ assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index a79566b1f3658..fa4f137b703b4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.types._
class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val carsFile = "cars.csv"
+ private val carsMalformedFile = "cars-malformed.csv"
private val carsFile8859 = "cars_iso-8859-1.csv"
private val carsTsvFile = "cars.tsv"
private val carsAltFile = "cars-alternative.csv"
@@ -191,6 +192,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
}
+ test("test for tokens more than the fields in the schema") {
+ val cars = sqlContext
+ .read
+ .format("csv")
+ .option("header", "false")
+ .option("comment", "~")
+ .load(testFile(carsMalformedFile))
+
+ verifyCars(cars, withHeader = false, checkTypes = false)
+ }
+
test("test with null quote character") {
val cars = sqlContext.read
.format("csv")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 00eaeb0d34e87..dd83a0e36f6f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -771,6 +771,34 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
}
+ test("Loading a JSON dataset floatAsBigDecimal returns schema with float types as BigDecimal") {
+ val jsonDF = sqlContext.read.option("floatAsBigDecimal", "true").json(primitiveFieldAndType)
+
+ val expectedSchema = StructType(
+ StructField("bigInteger", DecimalType(20, 0), true) ::
+ StructField("boolean", BooleanType, true) ::
+ StructField("double", DecimalType(17, -292), true) ::
+ StructField("integer", LongType, true) ::
+ StructField("long", LongType, true) ::
+ StructField("null", StringType, true) ::
+ StructField("string", StringType, true) :: Nil)
+
+ assert(expectedSchema === jsonDF.schema)
+
+ jsonDF.registerTempTable("jsonTable")
+
+ checkAnswer(
+ sql("select * from jsonTable"),
+ Row(BigDecimal("92233720368547758070"),
+ true,
+ BigDecimal("1.7976931348623157E308"),
+ 10,
+ 21474836470L,
+ null,
+ "this is a simple string.")
+ )
+ }
+
test("Loading a JSON dataset from a text file with SQL") {
val dir = Utils.createTempDir()
dir.delete()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
new file mode 100644
index 0000000000000..cef6b79a094d1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
+import org.apache.spark.sql.test.SharedSQLContext
+
+// TODO: this needs a lot more testing but it's currently not easy to test with the parquet
+// writer abstractions. Revisit.
+class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext {
+ import testImplicits._
+
+ val ROW = ((1).toByte, 2, 3L, "abc")
+ val NULL_ROW = (
+ null.asInstanceOf[java.lang.Byte],
+ null.asInstanceOf[Integer],
+ null.asInstanceOf[java.lang.Long],
+ null.asInstanceOf[String])
+
+ test("All Types Dictionary") {
+ (1 :: 1000 :: Nil).foreach { n => {
+ withTempPath { dir =>
+ List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath)
+ val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head
+
+ val reader = new UnsafeRowParquetRecordReader
+ reader.initialize(file.asInstanceOf[String], null)
+ val batch = reader.resultBatch()
+ assert(reader.nextBatch())
+ assert(batch.numRows() == n)
+ var i = 0
+ while (i < n) {
+ assert(batch.column(0).getByte(i) == 1)
+ assert(batch.column(1).getInt(i) == 2)
+ assert(batch.column(2).getLong(i) == 3)
+ assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc")
+ i += 1
+ }
+ reader.close()
+ }
+ }}
+ }
+
+ test("All Types Null") {
+ (1 :: 100 :: Nil).foreach { n => {
+ withTempPath { dir =>
+ val data = List.fill(n)(NULL_ROW).toDF
+ data.repartition(1).write.parquet(dir.getCanonicalPath)
+ val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head
+
+ val reader = new UnsafeRowParquetRecordReader
+ reader.initialize(file.asInstanceOf[String], null)
+ val batch = reader.resultBatch()
+ assert(reader.nextBatch())
+ assert(batch.numRows() == n)
+ var i = 0
+ while (i < n) {
+ assert(batch.column(0).getIsNull(i))
+ assert(batch.column(1).getIsNull(i))
+ assert(batch.column(2).getIsNull(i))
+ assert(batch.column(3).getIsNull(i))
+ i += 1
+ }
+ reader.close()
+ }}
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 97c5313f0feff..3ded32c450541 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -379,9 +380,62 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// If the "c = 1" filter gets pushed down, this query will throw an exception which
// Parquet emits. This is a Parquet issue (PARQUET-389).
+ val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a")
checkAnswer(
- sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a"),
- (1 to 1).map(i => Row(i, i.toString, null)))
+ df,
+ Row(1, "1", null))
+
+ // The fields "a" and "c" only exist in one Parquet file.
+ assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+
+ val pathThree = s"${dir.getCanonicalPath}/table3"
+ df.write.parquet(pathThree)
+
+ // We will remove the temporary metadata when writing the Parquet file.
+ val schema = sqlContext.read.parquet(pathThree).schema
+ assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
+
+ val pathFour = s"${dir.getCanonicalPath}/table4"
+ val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
+ dfStruct.select(struct("a").as("s")).write.parquet(pathFour)
+
+ val pathFive = s"${dir.getCanonicalPath}/table5"
+ val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b")
+ dfStruct2.select(struct("c").as("s")).write.parquet(pathFive)
+
+ // If the "s.c = 1" filter gets pushed down, this query will throw an exception which
+ // Parquet emits.
+ val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1")
+ .selectExpr("s")
+ checkAnswer(dfStruct3, Row(Row(null, 1)))
+
+ // The fields "s.a" and "s.c" only exist in one Parquet file.
+ val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType]
+ assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+
+ val pathSix = s"${dir.getCanonicalPath}/table6"
+ dfStruct3.write.parquet(pathSix)
+
+ // We will remove the temporary metadata when writing the Parquet file.
+ val forPathSix = sqlContext.read.parquet(pathSix).schema
+ assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
+
+ // Sanity check: make sure the optional metadata field is not set incorrectly.
+ val pathSeven = s"${dir.getCanonicalPath}/table7"
+ (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven)
+ val pathEight = s"${dir.getCanonicalPath}/table8"
+ (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight)
+
+ val df2 = sqlContext.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b")
+ checkAnswer(
+ df2,
+ Row(1, "1"))
+
+ // The fields "a" and "b" exist in both two Parquet files. No metadata is set.
+ assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField))
+ assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 60fa81b1ab819..90e3d50714ef3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.parquet.schema.MessageTypeParser
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -260,7 +261,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest {
int96AsTimestamp = true,
writeLegacyParquetFormat = true)
- testSchemaInference[Tuple1[Pair[Int, String]]](
+ testSchemaInference[Tuple1[(Int, String)]](
"struct",
"""
|message root {
@@ -449,6 +450,35 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
}.getMessage.contains("detected conflicting schemas"))
}
+ test("schema merging failure error message") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ sqlContext.range(3).write.parquet(s"$path/p=1")
+ sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2")
+
+ val message = intercept[SparkException] {
+ sqlContext.read.option("mergeSchema", "true").parquet(path).schema
+ }.getMessage
+
+ assert(message.contains("Failed merging schema of file"))
+ }
+
+ // Test the second merge step (after the Parquet schemas have been read in parallel).
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ sqlContext.range(3).write.parquet(s"$path/p=1")
+ sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2")
+
+ sqlContext.sparkContext.conf.set("spark.default.parallelism", "20")
+
+ val message = intercept[SparkException] {
+ sqlContext.read.option("mergeSchema", "true").parquet(path).schema
+ }.getMessage
+
+ assert(message.contains("Failed merging schema:"))
+ }
+ }
+
// =======================================================
// Tests for converting Parquet LIST to Catalyst ArrayType
// =======================================================
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index cbae19ebd269d..2260e4870299a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{JsonProtocol, Utils}
class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
@@ -335,23 +335,47 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("save metrics") {
withTempPath { file =>
- val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
- // Assume the execution plan is
- // PhysicalRDD(nodeId = 0)
- person.select('name).write.format("json").save(file.getAbsolutePath)
- sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
- assert(executionIds.size === 1)
- val executionId = executionIds.head
- val jobs = sqlContext.listener.getExecution(executionId).get.jobs
- // Use "<=" because there is a race condition that we may miss some jobs
- // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
- assert(jobs.size <= 1)
- val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
- // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
- // However, we still can check the value.
- assert(metricValues.values.toSeq === Seq("2"))
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
+ // Assume the execution plan is
+ // PhysicalRDD(nodeId = 0)
+ person.select('name).write.format("json").save(file.getAbsolutePath)
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ assert(executionIds.size === 1)
+ val executionId = executionIds.head
+ val jobs = sqlContext.listener.getExecution(executionId).get.jobs
+ // Use "<=" because there is a race condition that we may miss some jobs
+ // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
+ assert(jobs.size <= 1)
+ val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
+ // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
+ // However, we still can check the value.
+ assert(metricValues.values.toSeq === Seq("2"))
+ }
+ }
+ }
+
+ test("metrics can be loaded by history server") {
+ val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam)
+ metric += 10L
+ val metricInfo = metric.toInfo(Some(metric.localValue), None)
+ metricInfo.update match {
+ case Some(v: LongSQLMetricValue) => assert(v.value === 10L)
+ case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}")
+ case _ => fail("metric update is missing")
+ }
+ assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))
+ // After serializing to JSON, the original value type is lost, but we can still
+ // identify that it's a SQL metric from the metadata
+ val metricInfoJson = JsonProtocol.accumulableInfoToJson(metricInfo)
+ val metricInfoDeser = JsonProtocol.accumulableInfoFromJson(metricInfoJson)
+ metricInfoDeser.update match {
+ case Some(v: String) => assert(v.toLong === 10L)
+ case Some(v) => fail(s"deserialized metric value was not a string: ${v.getClass.getName}")
+ case _ => fail("deserialized metric update is missing")
}
+ assert(metricInfoDeser.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 2c408c8878470..085e4a49a57e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -26,8 +26,9 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.LongSQLMetricValue
+import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.ui.SparkUI
class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
@@ -335,8 +336,43 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1)
}
+ test("SPARK-13055: history listener only tracks SQL metrics") {
+ val listener = new SQLHistoryListener(sparkContext.conf, mock(classOf[SparkUI]))
+ // We need to post other events for the listener to track our accumulators.
+ // These are largely just boilerplate unrelated to what we're trying to test.
+ val df = createTestDataFrame
+ val executionStart = SparkListenerSQLExecutionStart(
+ 0, "", "", "", SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), 0)
+ val stageInfo = createStageInfo(0, 0)
+ val jobStart = SparkListenerJobStart(0, 0, Seq(stageInfo), createProperties(0))
+ val stageSubmitted = SparkListenerStageSubmitted(stageInfo)
+ // This task has both accumulators that are SQL metrics and accumulators that are not.
+ // The listener should only track the ones that are actually SQL metrics.
+ val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella")
+ val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
+ val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None)
+ val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)
+ val taskInfo = createTaskInfo(0, 0)
+ taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo)
+ val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null)
+ listener.onOtherEvent(executionStart)
+ listener.onJobStart(jobStart)
+ listener.onStageSubmitted(stageSubmitted)
+ // Before SPARK-13055, this throws ClassCastException because the history listener would
+ // assume that the accumulator value is of type Long, but this may not be true for
+ // accumulators that are not SQL metrics.
+ listener.onTaskEnd(taskEnd)
+ val trackedAccums = listener.stageIdToStageMetrics.values.flatMap { stageMetrics =>
+ stageMetrics.taskIdToMetricUpdates.values.flatMap(_.accumulatorUpdates)
+ }
+ // Listener tracks only SQL metrics, not other accumulators
+ assert(trackedAccums.size === 1)
+ assert(trackedAccums.head === sqlMetricInfo)
+ }
+
}
+
class SQLListenerMemoryLeakSuite extends SparkFunSuite {
test("no memory leak") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 215ca9ab6b770..445f311107e33 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.types.CalendarInterval
class ColumnarBatchSuite extends SparkFunSuite {
test("Null Apis") {
@@ -439,10 +440,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
c2.putDouble(1, 5.67)
val s = column.getStruct(0)
- assert(s.fields(0).getInt(0) == 123)
- assert(s.fields(0).getInt(1) == 456)
- assert(s.fields(1).getDouble(0) == 3.45)
- assert(s.fields(1).getDouble(1) == 5.67)
+ assert(s.columns()(0).getInt(0) == 123)
+ assert(s.columns()(0).getInt(1) == 456)
+ assert(s.columns()(1).getDouble(0) == 3.45)
+ assert(s.columns()(1).getDouble(1) == 5.67)
assert(s.getInt(0) == 123)
assert(s.getDouble(1) == 3.45)
@@ -571,7 +572,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
}}
}
-
private def doubleEquals(d1: Double, d2: Double): Boolean = {
if (d1.isNaN && d2.isNaN) {
true
@@ -585,13 +585,23 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed)
if (!r1.isNullAt(v._2)) {
v._1.dataType match {
+ case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed)
case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed)
+ case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed)
case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed)
case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed)
+ case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)),
+ "Seed = " + seed)
case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)),
"Seed = " + seed)
+ case t: DecimalType =>
+ val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal
+ val d2 = r2.getDecimal(v._2)
+ assert(d1.compare(d2) == 0, "Seed = " + seed)
case StringType =>
assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed)
+ case CalendarIntervalType =>
+ assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval])
case ArrayType(childType, n) =>
val a1 = r1.getArray(v._2).array
val a2 = r2.getList(v._2).toArray
@@ -605,6 +615,27 @@ class ColumnarBatchSuite extends SparkFunSuite {
i += 1
}
}
+ case FloatType => {
+ var i = 0
+ while (i < a1.length) {
+ assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]),
+ "Seed = " + seed)
+ i += 1
+ }
+ }
+
+ case t: DecimalType =>
+ var i = 0
+ while (i < a1.length) {
+ assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed)
+ if (a1(i) != null) {
+ val d1 = a1(i).asInstanceOf[Decimal].toBigDecimal
+ val d2 = a2(i).asInstanceOf[java.math.BigDecimal]
+ assert(d1.compare(d2) == 0, "Seed = " + seed)
+ }
+ i += 1
+ }
+
case _ => assert(a1 === a2, "Seed = " + seed)
}
case StructType(childFields) =>
@@ -644,10 +675,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
* results.
*/
def testRandomRows(flatSchema: Boolean, numFields: Int) {
- // TODO: add remaining types. Figure out why StringType doesn't work on jenkins.
- val types = Array(ByteType, IntegerType, LongType, DoubleType)
+ // TODO: Figure out why StringType doesn't work on jenkins.
+ val types = Array(
+ BooleanType, ByteType, FloatType, DoubleType,
+ IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10),
+ CalendarIntervalType)
val seed = System.nanoTime()
- val NUM_ROWS = 500
+ val NUM_ROWS = 200
val NUM_ITERS = 1000
val random = new Random(seed)
var i = 0
@@ -682,7 +716,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
}
test("Random flat schema") {
- testRandomRows(true, 10)
+ testRandomRows(true, 15)
}
test("Random nested schema") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 6fc9febe49707..cb88a1c83c999 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -22,7 +22,6 @@ import java.io.{File, IOException}
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.execution.datasources.DDLException
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
@@ -105,7 +104,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with
sql("SELECT a, b FROM jsonTable"),
sql("SELECT a, b FROM jt").collect())
- val message = intercept[DDLException]{
+ val message = intercept[AnalysisException]{
sql(
s"""
|CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable
@@ -156,7 +155,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with
}
test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") {
- val message = intercept[DDLException]{
+ val message = intercept[AnalysisException]{
sql(
s"""
|CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable
@@ -173,7 +172,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with
}
test("a CTAS statement with column definitions is not allowed") {
- intercept[DDLException]{
+ intercept[AnalysisException]{
sql(
s"""
|CREATE TEMPORARY TABLE jsonTable (a int, b string)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
new file mode 100644
index 0000000000000..36212e4395985
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.test
+
+import org.apache.spark.sql.{AnalysisException, SQLContext, StreamTest}
+import org.apache.spark.sql.execution.streaming.{Batch, Offset, Sink, Source}
+import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+object LastOptions {
+ var parameters: Map[String, String] = null
+ var schema: Option[StructType] = null
+ var partitionColumns: Seq[String] = Nil
+}
+
+/** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. */
+class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
+ override def createSource(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: Option[StructType]): Source = {
+ LastOptions.parameters = parameters
+ LastOptions.schema = schema
+ new Source {
+ override def getNextBatch(start: Option[Offset]): Option[Batch] = None
+ override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
+ }
+ }
+
+ override def createSink(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ partitionColumns: Seq[String]): Sink = {
+ LastOptions.parameters = parameters
+ LastOptions.partitionColumns = partitionColumns
+ new Sink {
+ override def addBatch(batch: Batch): Unit = {}
+ override def currentOffset: Option[Offset] = None
+ }
+ }
+}
+
+class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext {
+ import testImplicits._
+
+ test("resolve default source") {
+ sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ .write
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ .stop()
+ }
+
+ test("resolve full class") {
+ sqlContext.read
+ .format("org.apache.spark.sql.streaming.test.DefaultSource")
+ .stream()
+ .write
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ .stop()
+ }
+
+ test("options") {
+ val map = new java.util.HashMap[String, String]
+ map.put("opt3", "3")
+
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .option("opt1", "1")
+ .options(Map("opt2" -> "2"))
+ .options(map)
+ .stream()
+
+ assert(LastOptions.parameters("opt1") == "1")
+ assert(LastOptions.parameters("opt2") == "2")
+ assert(LastOptions.parameters("opt3") == "3")
+
+ LastOptions.parameters = null
+
+ df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .option("opt1", "1")
+ .options(Map("opt2" -> "2"))
+ .options(map)
+ .stream()
+ .stop()
+
+ assert(LastOptions.parameters("opt1") == "1")
+ assert(LastOptions.parameters("opt2") == "2")
+ assert(LastOptions.parameters("opt3") == "3")
+ }
+
+ test("partitioning") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+
+ df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ .stop()
+ assert(LastOptions.partitionColumns == Nil)
+
+ df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .partitionBy("a")
+ .stream()
+ .stop()
+ assert(LastOptions.partitionColumns == Seq("a"))
+
+ withSQLConf("spark.sql.caseSensitive" -> "false") {
+ df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .partitionBy("A")
+ .stream()
+ .stop()
+ assert(LastOptions.partitionColumns == Seq("a"))
+ }
+
+ intercept[AnalysisException] {
+ df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .partitionBy("b")
+ .stream()
+ .stop()
+ }
+ }
+
+ test("stream paths") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream("/test")
+
+ assert(LastOptions.parameters("path") == "/test")
+
+ LastOptions.parameters = null
+
+ df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .stream("/test")
+ .stop()
+
+ assert(LastOptions.parameters("path") == "/test")
+ }
+
+ test("test different data types for options") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .option("intOpt", 56)
+ .option("boolOpt", false)
+ .option("doubleOpt", 6.7)
+ .stream("/test")
+
+ assert(LastOptions.parameters("intOpt") == "56")
+ assert(LastOptions.parameters("boolOpt") == "false")
+ assert(LastOptions.parameters("doubleOpt") == "6.7")
+
+ LastOptions.parameters = null
+ df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .option("intOpt", 56)
+ .option("boolOpt", false)
+ .option("doubleOpt", 6.7)
+ .stream("/test")
+ .stop()
+
+ assert(LastOptions.parameters("intOpt") == "56")
+ assert(LastOptions.parameters("boolOpt") == "false")
+ assert(LastOptions.parameters("doubleOpt") == "6.7")
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala
new file mode 100644
index 0000000000000..81760d2aa8205
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.sql.StreamTest
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class MemorySourceStressSuite extends StreamTest with SharedSQLContext {
+ import testImplicits._
+
+ test("memory stress test") {
+ val input = MemoryStream[Int]
+ val mapped = input.toDS().map(_ + 1)
+
+ runStressTest(mapped, AddData(input, _: _*))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala
new file mode 100644
index 0000000000000..989465826d54e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, Offset}
+
+trait OffsetSuite extends SparkFunSuite {
+ /** Creates a test that checks all comparisons between offsets, given a `one` that is less than `two`. */
+ def compare(one: Offset, two: Offset): Unit = {
+ test(s"comparision $one <=> $two") {
+ assert(one < two)
+ assert(one <= two)
+ assert(one <= one)
+ assert(two > one)
+ assert(two >= one)
+ assert(one >= one)
+ assert(one == one)
+ assert(two == two)
+ assert(one != two)
+ assert(two != one)
+ }
+ }
+
+ /** Creates a test that checks that ordering (non-equality) comparisons throw an exception. */
+ def compareInvalid(one: Offset, two: Offset): Unit = {
+ test(s"invalid comparison $one <=> $two") {
+ intercept[IllegalArgumentException] {
+ assert(one < two)
+ }
+
+ intercept[IllegalArgumentException] {
+ assert(one <= two)
+ }
+
+ intercept[IllegalArgumentException] {
+ assert(one > two)
+ }
+
+ intercept[IllegalArgumentException] {
+ assert(one >= two)
+ }
+
+ assert(!(one == two))
+ assert(!(two == one))
+ assert(one != two)
+ assert(two != one)
+ }
+ }
+}
+
+class LongOffsetSuite extends OffsetSuite {
+ val one = LongOffset(1)
+ val two = LongOffset(2)
+ compare(one, two)
+}
+
+class CompositeOffsetSuite extends OffsetSuite {
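+ // Composite offsets are compared element-wise: they are only comparable when their
+ // sizes match and every component agrees on the ordering.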
+ compare(
+ one = CompositeOffset(Some(LongOffset(1)) :: Nil),
+ two = CompositeOffset(Some(LongOffset(2)) :: Nil))
+
+ compare(
+ one = CompositeOffset(None :: Nil),
+ two = CompositeOffset(Some(LongOffset(2)) :: Nil))
+
+ compareInvalid( // sizes must be the same
+ one = CompositeOffset(Nil),
+ two = CompositeOffset(Some(LongOffset(2)) :: Nil))
+
+ compare(
+ one = CompositeOffset.fill(LongOffset(0), LongOffset(1)),
+ two = CompositeOffset.fill(LongOffset(1), LongOffset(2)))
+
+ compare(
+ one = CompositeOffset.fill(LongOffset(1), LongOffset(1)),
+ two = CompositeOffset.fill(LongOffset(1), LongOffset(2)))
+
+ compareInvalid(
+ one = CompositeOffset.fill(LongOffset(2), LongOffset(1)), // vector times are inconsistent
+ two = CompositeOffset.fill(LongOffset(1), LongOffset(2)))
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
new file mode 100644
index 0000000000000..fbb1792596b18
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.sql.{Row, StreamTest}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class StreamSuite extends StreamTest with SharedSQLContext {
+
+ import testImplicits._
+
+ test("map with recovery") {
+ val inputData = MemoryStream[Int]
+ val mapped = inputData.toDS().map(_ + 1)
+
+ testStream(mapped)(
+ AddData(inputData, 1, 2, 3),
+ StartStream,
+ CheckAnswer(2, 3, 4),
+ StopStream,
+ AddData(inputData, 4, 5, 6),
+ StartStream,
+ CheckAnswer(2, 3, 4, 5, 6, 7))
+ }
+
+ test("join") {
+ // Make a table and ensure it will be broadcast.
+ val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word")
+
+ // Join the input stream with a table.
+ val inputData = MemoryStream[Int]
+ val joined = inputData.toDS().toDF().join(smallTable, $"value" === $"number")
+
+ testStream(joined)(
+ AddData(inputData, 1, 2, 3),
+ CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two")),
+ AddData(inputData, 4),
+ CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four")))
+ }
+
+ test("union two streams") {
+ val inputData1 = MemoryStream[Int]
+ val inputData2 = MemoryStream[Int]
+
+ val unioned = inputData1.toDS().union(inputData2.toDS())
+
+ testStream(unioned)(
+ AddData(inputData1, 1, 3, 5),
+ CheckAnswer(1, 3, 5),
+ AddData(inputData2, 2, 4, 6),
+ CheckAnswer(1, 2, 3, 4, 5, 6),
+ StopStream,
+ AddData(inputData1, 7),
+ StartStream,
+ AddData(inputData2, 8),
+ CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8))
+ }
+
+ test("sql queries") {
+ val inputData = MemoryStream[Int]
+ inputData.toDF().registerTempTable("stream")
+ val evens = sql("SELECT * FROM stream WHERE value % 2 = 0")
+
+ testStream(evens)(
+ AddData(inputData, 1, 2, 3, 4),
+ CheckAnswer(2, 4))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d48143762cac0..7d6bff8295d2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils
val schema = df.schema
val childRDD = df
.queryExecution
- .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
+ .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
.child
.execute()
.map(row => Row.fromSeq(row.copy().toSeq(schema)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index e7b376548787c..c341191c70bb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -36,7 +36,7 @@ trait SharedSQLContext extends SQLTestUtils {
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*/
- protected def sqlContext: SQLContext = _ctx
+ protected implicit def sqlContext: SQLContext = _ctx
/**
* Initialize the [[TestSQLContext]].
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 9a24a2487a254..a3e5243b68aba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
}
sqlContext.listenerManager.register(listener)
- val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
- df.collect()
- df.collect()
- Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+ df.collect()
+ df.collect()
+ Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+ }
assert(metrics.length == 3)
assert(metrics(0) == 1)
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 435e565f63458..c8d17bd468582 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -21,13 +21,13 @@
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
    <version>2.0.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
  </parent>

  <groupId>org.apache.spark</groupId>
-  <artifactId>spark-hive-thriftserver_2.10</artifactId>
+  <artifactId>spark-hive-thriftserver_2.11</artifactId>
  <packaging>jar</packaging>
  <name>Spark Project Hive Thrift Server</name>
  <url>http://spark.apache.org/</url>
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index ab31d45a79a2e..72da266da4d01 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -183,7 +183,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
"CREATE DATABASE hive_test_db;"
-> "OK",
"USE hive_test_db;"
- -> "OK",
+ -> "",
"CREATE TABLE hive_test(key INT, val STRING);"
-> "OK",
"SHOW TABLES;"
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index ba3b26e1b7d49..865197e24caf8 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -23,7 +23,7 @@ import java.sql.{Date, DriverManager, SQLException, Statement}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.{future, Await, ExecutionContext, Promise}
+import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import scala.io.Source
import scala.util.{Random, Try}
@@ -362,7 +362,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
try {
// Start a very-long-running query that will take hours to finish, then cancel it in order
// to demonstrate that cancellation works.
- val f = future {
+ val f = Future {
statement.executeQuery(
"SELECT COUNT(*) FROM test_map " +
List.fill(10)("join test_map").mkString(" "))
@@ -380,7 +380,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
// Cancellation is a no-op if spark.sql.hive.thriftServer.async=false
statement.executeQuery("SET spark.sql.hive.thriftServer.async=false")
try {
- val sf = future {
+ val sf = Future {
statement.executeQuery(
"SELECT COUNT(*) FROM test_map " +
List.fill(4)("join test_map").mkString(" ")
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 554d47d651aef..61b73fa557144 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -325,6 +325,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"drop_partitions_ignore_protection",
"protectmode",
+ // Hive returns null rather than NaN when n = 1
+ "udaf_covar_samp",
+
// Spark parser treats numerical literals differently: it creates decimals instead of doubles.
"udf_abs",
"udf_format_number",
@@ -881,7 +884,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"type_widening",
"udaf_collect_set",
"udaf_covar_pop",
- "udaf_covar_samp",
"udaf_histogram_numeric",
"udf2",
"udf5",
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index cd0c2aeb93a9f..14cf9acf09d5b 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -21,13 +21,13 @@
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.10</artifactId>
+    <artifactId>spark-parent_2.11</artifactId>
    <version>2.0.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
  </parent>

  <groupId>org.apache.spark</groupId>
-  <artifactId>spark-hive_2.10</artifactId>
+  <artifactId>spark-hive_2.11</artifactId>
  <packaging>jar</packaging>
  <name>Spark Project Hive</name>
  <url>http://spark.apache.org/</url>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 1797ea54f2501..05863ae18350d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -79,8 +79,8 @@ class HiveContext private[hive](
sc: SparkContext,
cacheManager: CacheManager,
listener: SQLListener,
- @transient private val execHive: ClientWrapper,
- @transient private val metaHive: ClientInterface,
+ @transient private val execHive: HiveClientImpl,
+ @transient private val metaHive: HiveClient,
isRootContext: Boolean)
extends SQLContext(sc, cacheManager, listener, isRootContext) with Logging {
self =>
@@ -193,7 +193,7 @@ class HiveContext private[hive](
* for storing persistent metadata, and only point to a dummy metastore in a temporary directory.
*/
@transient
- protected[hive] lazy val executionHive: ClientWrapper = if (execHive != null) {
+ protected[hive] lazy val executionHive: HiveClientImpl = if (execHive != null) {
execHive
} else {
logInfo(s"Initializing execution hive, version $hiveExecutionVersion")
@@ -203,7 +203,7 @@ class HiveContext private[hive](
config = newTemporaryConfiguration(useInMemoryDerby = true),
isolationOn = false,
baseClassLoader = Utils.getContextOrSparkClassLoader)
- loader.createClient().asInstanceOf[ClientWrapper]
+ loader.createClient().asInstanceOf[HiveClientImpl]
}
/**
@@ -222,7 +222,7 @@ class HiveContext private[hive](
* in the hive-site.xml file.
*/
@transient
- protected[hive] lazy val metadataHive: ClientInterface = if (metaHive != null) {
+ protected[hive] lazy val metadataHive: HiveClient = if (metaHive != null) {
metaHive
} else {
val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index a9c0e9ab7caef..61d0d6759ff72 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -96,7 +96,7 @@ private[hive] object HiveSerDe {
}
}
-private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext)
+private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext)
extends Catalog with Logging {
val conf = hive.conf
@@ -711,6 +711,10 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
}
override def unregisterAllTables(): Unit = {}
+
+ override def setCurrentDatabase(databaseName: String): Unit = {
+ client.setCurrentDatabase(databaseName)
+ }
}
/**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 22841ed2116d1..752c037a842a8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -155,8 +155,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging
"TOK_SHOWLOCKS",
"TOK_SHOWPARTITIONS",
- "TOK_SWITCHDATABASE",
-
"TOK_UNLOCKTABLE"
)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 1654594538366..fc5725d6915ea 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder}
-import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
+import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -188,7 +188,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
// The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over
// `Aggregate`s to perform type casting. This rule merges these `Project`s into
// `Aggregate`s.
- ProjectCollapsing,
+ CollapseProject,
// Used to handle other auxiliary `Project`s added by analyzer (e.g.
// `ResolveAggregateFunctions` rule)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
similarity index 94%
rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
index 9d9a55edd7314..f681cc67041a1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
@@ -60,9 +60,9 @@ private[hive] case class HiveTable(
viewText: Option[String] = None) {
@transient
- private[client] var client: ClientInterface = _
+ private[client] var client: HiveClient = _
- private[client] def withClient(ci: ClientInterface): this.type = {
+ private[client] def withClient(ci: HiveClient): this.type = {
client = ci
this
}
@@ -85,7 +85,7 @@ private[hive] case class HiveTable(
* internal and external classloaders for a given version of Hive and thus must expose only
* shared classes.
*/
-private[hive] trait ClientInterface {
+private[hive] trait HiveClient {
/** Returns the Hive Version of this client. */
def version: HiveVersion
@@ -109,6 +109,9 @@ private[hive] trait ClientInterface {
/** Returns the name of the active database. */
def currentDatabase: String
+ /** Sets the name of the current database. */
+ def setCurrentDatabase(databaseName: String): Unit
+
/** Returns the metadata for specified database, throwing an exception if it doesn't exist */
def getDatabase(name: String): HiveDatabase = {
getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException)
@@ -181,8 +184,8 @@ private[hive] trait ClientInterface {
/** Add a jar into class loader */
def addJar(path: String): Unit
- /** Return a ClientInterface as new session, that will share the class loader and Hive client */
- def newSession(): ClientInterface
+ /** Return a [[HiveClient]] as a new session that shares the class loader and the Hive client. */
+ def newSession(): HiveClient
/** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */
def withHiveState[A](f: => A): A
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
similarity index 96%
rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index ce7a305d437a5..cf1ff55c96fc9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader}
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.util.{CircularBuffer, Utils}
@@ -43,8 +44,8 @@ import org.apache.spark.util.{CircularBuffer, Utils}
* A class that wraps the HiveClient and converts its responses to externally visible classes.
* Note that this class is typically loaded with an internal classloader for each instantiation,
* allowing it to interact directly with a specific isolated version of Hive. Loading this class
- * with the isolated classloader however will result in it only being visible as a ClientInterface,
- * not a ClientWrapper.
+ * with the isolated classloader however will result in it only being visible as a [[HiveClient]],
+ * not a [[HiveClientImpl]].
*
* This class needs to interact with multiple versions of Hive, but will always be compiled with
* the 'native', execution version of Hive. Therefore, any places where hive breaks compatibility
@@ -54,14 +55,14 @@ import org.apache.spark.util.{CircularBuffer, Utils}
* @param config a collection of configuration options that will be added to the hive conf before
* opening the hive client.
* @param initClassLoader the classloader used when creating the `state` field of
- * this ClientWrapper.
+ * this [[HiveClientImpl]].
*/
-private[hive] class ClientWrapper(
+private[hive] class HiveClientImpl(
override val version: HiveVersion,
config: Map[String, String],
initClassLoader: ClassLoader,
val clientLoader: IsolatedClientLoader)
- extends ClientInterface
+ extends HiveClient
with Logging {
// Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur.
@@ -76,7 +77,7 @@ private[hive] class ClientWrapper(
case hive.v1_2 => new Shim_v1_2()
}
- // Create an internal session state for this ClientWrapper.
+ // Create an internal session state for this HiveClientImpl.
val state = {
val original = Thread.currentThread().getContextClassLoader
// Switch to the initClassLoader.
@@ -159,7 +160,7 @@ private[hive] class ClientWrapper(
case e: Exception if causedByThrift(e) =>
caughtException = e
logWarning(
- "HiveClientWrapper got thrift exception, destroying client and retrying " +
+ "HiveClient got thrift exception, destroying client and retrying " +
s"(${retryLimit - numTries} tries remaining)", e)
clientLoader.cachedHive = null
Thread.sleep(retryDelayMillis)
@@ -198,7 +199,7 @@ private[hive] class ClientWrapper(
*/
def withHiveState[A](f: => A): A = retryLocked {
val original = Thread.currentThread().getContextClassLoader
- // Set the thread local metastore client to the client associated with this ClientWrapper.
+ // Set the thread local metastore client to the client associated with this HiveClientImpl.
Hive.set(client)
// The classloader in clientLoader could be changed after addJar, always use the latest
// classloader
@@ -229,6 +230,14 @@ private[hive] class ClientWrapper(
state.getCurrentDatabase
}
+ override def setCurrentDatabase(databaseName: String): Unit = withHiveState {
+ if (getDatabaseOption(databaseName).isDefined) {
+ state.setCurrentDatabase(databaseName)
+ } else {
+ throw new NoSuchDatabaseException
+ }
+ }
+
override def createDatabase(database: HiveDatabase): Unit = withHiveState {
client.createDatabase(
new Database(
@@ -512,8 +521,8 @@ private[hive] class ClientWrapper(
runSqlHive(s"ADD JAR $path")
}
- def newSession(): ClientWrapper = {
- clientLoader.createClient().asInstanceOf[ClientWrapper]
+ def newSession(): HiveClientImpl = {
+ clientLoader.createClient().asInstanceOf[HiveClientImpl]
}
def reset(): Unit = withHiveState {
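
The `withHiveState` wrapper that the new `setCurrentDatabase` override goes through is essentially a classloader-swap-and-restore pattern. A minimal, self-contained sketch of that pattern follows; this is an illustration only, not the actual implementation, which additionally pins the thread-local `Hive` object and retries on thrift failures:

    // Illustration only: run `body` with an isolated classloader installed on the
    // current thread, and always restore the original classloader afterwards.
    def withIsolatedState[A](isolatedLoader: ClassLoader)(body: => A): A = {
      val original = Thread.currentThread().getContextClassLoader
      Thread.currentThread().setContextClassLoader(isolatedLoader)
      try body finally Thread.currentThread().setContextClassLoader(original)
    }
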
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index ca636b0265d41..70c10be25be9f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -38,8 +38,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{IntegralType, StringType}
/**
- * A shim that defines the interface between ClientWrapper and the underlying Hive library used to
- * talk to the metastore. Each Hive version has its own implementation of this class, defining
+ * A shim that defines the interface between [[HiveClientImpl]] and the underlying Hive library used
+ * to talk to the metastore. Each Hive version has its own implementation of this class, defining
* version-specific versions of the functions it needs.
*
* The guideline for writing shims is:
@@ -52,7 +52,6 @@ private[client] sealed abstract class Shim {
/**
* Set the current SessionState to the given SessionState. Also, set the context classloader of
* the current thread to the one set in the HiveConf of this given `state`.
- * @param state
*/
def setCurrentSessionState(state: SessionState): Unit
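
The shim selection itself happens once per client. A hypothetical sketch of that dispatch is shown below; the real mapping lives in `HiveClientImpl` and keys off the `HiveVersion` objects rather than raw strings:

    // Hypothetical dispatch, for illustration: one Shim implementation per
    // supported metastore version, chosen when the client is constructed.
    def shimFor(version: String): Shim = version match {
      case "0.12" => new Shim_v0_12()
      case "0.13" => new Shim_v0_13()
      case "1.2"  => new Shim_v1_2()
      case other  => throw new IllegalArgumentException(s"Unsupported Hive version: $other")
    }
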
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index 010051d255fdc..dca7396ee1ab4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -124,15 +124,15 @@ private[hive] object IsolatedClientLoader extends Logging {
}
/**
- * Creates a Hive `ClientInterface` using a classloader that works according to the following rules:
+ * Creates a [[HiveClient]] using a classloader that works according to the following rules:
* - Shared classes: Java, Scala, logging, and Spark classes are delegated to `baseClassLoader`
- * allowing the results of calls to the `ClientInterface` to be visible externally.
+ * allowing the results of calls to the [[HiveClient]] to be visible externally.
* - Hive classes: new instances are loaded from `execJars`. These classes are not
* accessible externally due to their custom loading.
- * - ClientWrapper: a new copy is created for each instance of `IsolatedClassLoader`.
+ * - [[HiveClientImpl]]: a new copy is created for each instance of `IsolatedClassLoader`.
* This new instance is able to see a specific version of hive without using reflection. Since
* this is a unique instance, it is not visible externally other than as a generic
- * `ClientInterface`, unless `isolationOn` is set to `false`.
+ * [[HiveClient]], unless `isolationOn` is set to `false`.
*
* @param version The version of hive on the classpath. used to pick specific function signatures
* that are not compatible across versions.
@@ -179,7 +179,7 @@ private[hive] class IsolatedClientLoader(
/** True if `name` refers to a spark class that must see specific version of Hive. */
protected def isBarrierClass(name: String): Boolean =
- name.startsWith(classOf[ClientWrapper].getName) ||
+ name.startsWith(classOf[HiveClientImpl].getName) ||
name.startsWith(classOf[Shim].getName) ||
barrierPrefixes.exists(name.startsWith)
@@ -233,9 +233,9 @@ private[hive] class IsolatedClientLoader(
}
/** The isolated client interface to Hive. */
- private[hive] def createClient(): ClientInterface = {
+ private[hive] def createClient(): HiveClient = {
if (!isolationOn) {
- return new ClientWrapper(version, config, baseClassLoader, this)
+ return new HiveClientImpl(version, config, baseClassLoader, this)
}
// Pre-reflective instantiation setup.
logDebug("Initializing the logger to avoid disaster...")
@@ -244,10 +244,10 @@ private[hive] class IsolatedClientLoader(
try {
classLoader
- .loadClass(classOf[ClientWrapper].getName)
+ .loadClass(classOf[HiveClientImpl].getName)
.getConstructors.head
.newInstance(version, config, classLoader, this)
- .asInstanceOf[ClientInterface]
+ .asInstanceOf[HiveClient]
} catch {
case e: InvocationTargetException =>
if (e.getCause().isInstanceOf[NoClassDefFoundError]) {
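
The reflective instantiation shown in this hunk is what keeps `HiveClientImpl` invisible outside the isolated loader. A condensed sketch of the same idea, assuming a `classLoader` that treats `HiveClientImpl` as a barrier class:

    // Sketch of the barrier-class pattern: load the implementation inside the
    // isolated classloader, but only ever expose it through the shared HiveClient trait.
    def reflectiveClient(classLoader: ClassLoader, ctorArgs: AnyRef*): HiveClient = {
      classLoader
        .loadClass("org.apache.spark.sql.hive.client.HiveClientImpl")
        .getConstructors.head
        .newInstance(ctorArgs: _*)
        .asInstanceOf[HiveClient]
    }
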
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 56cab1aee89df..d5ed838ca4b1a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -38,13 +38,13 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.hive.HiveShim._
-import org.apache.spark.sql.hive.client.ClientWrapper
+import org.apache.spark.sql.hive.client.HiveClientImpl
import org.apache.spark.sql.types._
private[hive] class HiveFunctionRegistry(
underlying: analysis.FunctionRegistry,
- executionHive: ClientWrapper)
+ executionHive: HiveClientImpl)
extends analysis.FunctionRegistry with HiveInspectors {
def getFunctionInfo(name: String): FunctionInfo = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index a33223af24370..246108e0d0e11 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.CacheTableCommand
import org.apache.spark.sql.hive._
-import org.apache.spark.sql.hive.client.ClientWrapper
+import org.apache.spark.sql.hive.client.HiveClientImpl
import org.apache.spark.sql.hive.execution.HiveNativeCommand
import org.apache.spark.util.{ShutdownHookManager, Utils}
@@ -458,7 +458,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive)
}
-private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: ClientWrapper)
+private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: HiveClientImpl)
extends HiveFunctionRegistry(fr, client) {
private val removedFunctions =
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index ff10a251f3b45..1344a2cc4bd37 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.tags.ExtendedHiveTest
import org.apache.spark.util.Utils
/**
- * A simple set of tests that call the methods of a hive ClientInterface, loading different version
+ * A simple set of tests that call the methods of a [[HiveClient]], loading different versions
* of hive from maven central. These tests are simple in that they are mostly just testing to make
* sure that reflective calls are not throwing NoSuchMethod errors, but the actual functionality
* is not fully tested.
@@ -101,7 +101,7 @@ class VersionsSuite extends SparkFunSuite with Logging {
private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0")
- private var client: ClientInterface = null
+ private var client: HiveClient = null
versions.foreach { version =>
test(s"$version: create client") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 3e4cf3f79e57c..caf1db9ad0855 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -193,6 +193,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
sqlContext.dropTempTable("emptyTable")
}
+ test("group by function") {
+ Seq((1, 2)).toDF("a", "b").registerTempTable("data")
+
+ checkAnswer(
+ sql("SELECT floor(a) AS a, collect_set(b) FROM data GROUP BY floor(a) ORDER BY a"),
+ Row(1, Array(2)) :: Nil)
+ }
+
test("empty table") {
// If there is no GROUP BY clause and the table is empty, we will generate a single row.
checkAnswer(
@@ -790,7 +798,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
"""
|SELECT corr(b, c) FROM covar_tab WHERE a = 3
""".stripMargin),
- Row(null) :: Nil)
+ Row(Double.NaN) :: Nil)
checkAnswer(
sqlContext.sql(
@@ -799,10 +807,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
""".stripMargin),
Row(1, null) ::
Row(2, null) ::
- Row(3, null) ::
- Row(4, null) ::
- Row(5, null) ::
- Row(6, null) :: Nil)
+ Row(3, Double.NaN) ::
+ Row(4, Double.NaN) ::
+ Row(5, Double.NaN) ::
+ Row(6, Double.NaN) :: Nil)
val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
@@ -833,11 +841,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
// one row test
val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
- val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0)
- assert(cov_samp3 == null)
-
- val cov_pop3 = df3.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0)
- assert(cov_pop3 == 0.0)
+ checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(Double.NaN))
+ checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0))
}
test("no aggregation function (SPARK-11486)") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 4659d745fe78b..1337a25eb26a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.{SparkException, SparkFiles}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin
@@ -769,14 +770,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") {
val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3))
- .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)}
+ .zipWithIndex.map {case ((value, attr), key) => HavingRow(key, value, attr)}
TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test")
val results =
sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3")
.collect()
- .map(x => Pair(x.getString(0), x.getInt(1)))
+ .map(x => (x.getString(0), x.getInt(1)))
- assert(results === Array(Pair("foo", 4)))
+ assert(results === Array(("foo", 4)))
TestHive.reset()
}
@@ -1262,6 +1263,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
+ test("use database") {
+ val currentDatabase = sql("select current_database()").first().getString(0)
+
+ sql("CREATE DATABASE hive_test_db")
+ sql("USE hive_test_db")
+ assert("hive_test_db" == sql("select current_database()").first().getString(0))
+
+ intercept[NoSuchDatabaseException] {
+ sql("USE not_existing_db")
+ }
+
+ sql(s"USE $currentDatabase")
+ assert(currentDatabase == sql("select current_database()").first().getString(0))
+ }
+
test("lookup hive UDF in another thread") {
val e = intercept[AnalysisException] {
range(1).selectExpr("not_a_udf()")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 0d62d799c8dce..6048b8f5a3998 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -199,7 +199,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
"Extended Usage")
checkExistence(sql("describe functioN abcadf"), true,
- "Function: abcadf is not found.")
+ "Function: abcadf not found.")
checkExistence(sql("describe functioN `~`"), true,
"Function: ~",
@@ -736,7 +736,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
""".stripMargin), (2 to 6).map(i => Row(i)))
}
- test("window function: udaf with aggregate expressin") {
+ test("window function: udaf with aggregate expression") {
val data = Seq(
WindowData(1, "a", 5),
WindowData(2, "a", 6),
@@ -927,6 +927,88 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
).map(i => Row(i._1, i._2, i._3, i._4)))
}
+ test("window function: Sorting columns are not in Project") {
+ val data = Seq(
+ WindowData(1, "d", 10),
+ WindowData(2, "a", 6),
+ WindowData(3, "b", 7),
+ WindowData(4, "b", 8),
+ WindowData(5, "c", 9),
+ WindowData(6, "c", 11)
+ )
+ sparkContext.parallelize(data).toDF().registerTempTable("windowData")
+
+ checkAnswer(
+ sql("select month, product, sum(product + 1) over() from windowData order by area"),
+ Seq(
+ (2, 6, 57),
+ (3, 7, 57),
+ (4, 8, 57),
+ (5, 9, 57),
+ (6, 11, 57),
+ (1, 10, 57)
+ ).map(i => Row(i._1, i._2, i._3)))
+
+ checkAnswer(
+ sql(
+ """
+ |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1
+ |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p
+ """.stripMargin),
+ Seq(
+ ("a", 2),
+ ("b", 2),
+ ("b", 3),
+ ("c", 2),
+ ("d", 2),
+ ("c", 3)
+ ).map(i => Row(i._1, i._2)))
+
+ checkAnswer(
+ sql(
+ """
+ |select area, rank() over (partition by area order by month) as c1
+ |from windowData group by product, area, month order by product, area
+ """.stripMargin),
+ Seq(
+ ("a", 1),
+ ("b", 1),
+ ("b", 2),
+ ("c", 1),
+ ("d", 1),
+ ("c", 2)
+ ).map(i => Row(i._1, i._2)))
+ }
+
+ // todo: fix this test case by reimplementing the function ResolveAggregateFunctions
+ ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") {
+ val data = Seq(
+ WindowData(1, "d", 10),
+ WindowData(2, "a", 6),
+ WindowData(3, "b", 7),
+ WindowData(4, "b", 8),
+ WindowData(5, "c", 9),
+ WindowData(6, "c", 11)
+ )
+ sparkContext.parallelize(data).toDF().registerTempTable("windowData")
+
+ checkAnswer(
+ sql(
+ """
+ |select area, sum(product) over () as c from windowData
+ |where product > 3 group by area, product
+ |having avg(month) > 0 order by avg(month), product
+ """.stripMargin),
+ Seq(
+ ("a", 51),
+ ("b", 51),
+ ("b", 51),
+ ("c", 51),
+ ("c", 51),
+ ("d", 51)
+ ).map(i => Row(i._1, i._2)))
+ }
+
test("window function: multiple window expressions in a single expression") {
val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
nums.registerTempTable("nums")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 150d0c748631e..9ba645626fe72 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -19,22 +19,28 @@ package org.apache.spark.sql.sources
import java.io.File
-import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf}
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.Exchange
-import org.apache.spark.sql.execution.datasources.BucketSpec
+import org.apache.spark.sql.execution.{Exchange, PhysicalRDD}
+import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy}
import org.apache.spark.sql.execution.joins.SortMergeJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.BitSet
class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
+ private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+ private val nullDF = (for {
+ i <- 0 to 50
+ s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g")
+ } yield (i % 5, s, i % 13)).toDF("i", "j", "k")
+
test("read bucketed data") {
- val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
withTable("bucketed_table") {
df.write
.format("parquet")
@@ -59,6 +65,152 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
}
}
+ // To verify if the bucket pruning works, this function checks two conditions:
+ // 1) Check if the pruned buckets (before filtering) are empty.
+ // 2) Verify the final result is the same as the expected one
+ private def checkPrunedAnswers(
+ bucketSpec: BucketSpec,
+ bucketValues: Seq[Integer],
+ filterCondition: Column,
+ originalDataFrame: DataFrame): Unit = {
+
+ val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k")
+ val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
+ // Limit: bucket pruning only works when there is exactly one bucketing column
+ assert(bucketColumnNames.length == 1)
+ val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head)
+ val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
+ val matchedBuckets = new BitSet(numBuckets)
+ bucketValues.foreach { value =>
+ matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value))
+ }
+
+ // A Filter above the scan could mask bucket-pruning bugs, so inspect the scan output directly
+ val rdd = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan
+ .find(_.isInstanceOf[PhysicalRDD])
+ assert(rdd.isDefined)
+
+ val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
+ if (matchedBuckets.get(index % numBuckets)) Iterator(true) else Iterator(iter.isEmpty)
+ }
+ // checking if all the pruned buckets are empty
+ assert(checkedResult.collect().forall(_ == true))
+
+ checkAnswer(
+ bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"),
+ originalDataFrame.filter(filterCondition).orderBy("i", "j", "k"))
+ }
+
+ test("read partitioning bucketed tables with bucket pruning filters") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // json does not support predicate push-down, and thus json is used here
+ df.write
+ .format("json")
+ .partitionBy("i")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ for (j <- 0 until 13) {
+ // Case 1: EqualTo
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j,
+ df)
+
+ // Case 2: EqualNullSafe
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" <=> j,
+ df)
+
+ // Case 3: In
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = Seq(j, j + 1, j + 2, j + 3),
+ filterCondition = $"j".isin(j, j + 1, j + 2, j + 3),
+ df)
+ }
+ }
+ }
+
+ test("read non-partitioning bucketed tables with bucket pruning filters") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // json does not support predicate push-down, and thus json is used here
+ df.write
+ .format("json")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ for (j <- 0 until 13) {
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j,
+ df)
+ }
+ }
+ }
+
+ test("read partitioning bucketed tables having null in bucketing key") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // json does not support predicate push-down, and thus json is used here
+ nullDF.write
+ .format("json")
+ .partitionBy("i")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ // Case 1: isNull
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = null :: Nil,
+ filterCondition = $"j".isNull,
+ nullDF)
+
+ // Case 2: <=> null
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = null :: Nil,
+ filterCondition = $"j" <=> null,
+ nullDF)
+ }
+ }
+
+ test("read partitioning bucketed tables having composite filters") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // json does not support predicate push-down, and thus json is used here
+ df.write
+ .format("json")
+ .partitionBy("i")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ for (j <- 0 until 13) {
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j && $"k" > $"j",
+ df)
+
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j && $"i" > j % 5,
+ df)
+ }
+ }
+ }
+
private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
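
The pruning check above hinges on mapping each literal in the filter to a bucket id and then requiring every partition outside that set to be empty before filtering. A rough, hedged sketch of that idea; the real bucket id comes from `DataSourceStrategy.getBucketId`, which hashes the value the same way the writer did, so the non-negative modulo below is only an illustration:

    // Collect the bucket ids that may contain the filtered values; any other
    // bucket should produce no rows before filtering.
    def prunedBucketIds(numBuckets: Int, values: Seq[Any]): Set[Int] =
      values.map { v =>
        val h = if (v == null) 0 else v.hashCode
        ((h % numBuckets) + numBuckets) % numBuckets
      }.toSet
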
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 39cbd0d00f951..7d409c5d3b076 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -20,13 +20,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xmlorg.apache.spark
- spark-streaming_2.10
+ spark-streaming_2.11streaming
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index 66f646d7dc136..e6724feaee105 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -221,7 +221,7 @@ object StateSpec {
mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]):
StateSpec[KeyType, ValueType, StateType, MappedType] = {
val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => {
- mappingFunction.call(k, Optional.ofNullable(v.get), s)
+ mappingFunction.call(k, JavaUtils.optionToOptional(v), s)
}
StateSpec.function(wrappedFunc)
}
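
The old code called `v.get` unconditionally, which throws `NoSuchElementException` whenever the Java `mapWithState` function is invoked for a key with no new value; `JavaUtils.optionToOptional` converts the whole `Option` instead. Roughly what that conversion is assumed to do, written against `java.util.Optional` purely for illustration (the real code targets Spark's own `Optional`):

    import java.util.{Optional => JOptional}

    // Map Some(x) to a present Optional and None to an empty one, instead of
    // calling .get on a possibly empty Option.
    def toOptional[T](v: Option[T]): JOptional[T] =
      v.map(x => JOptional.of(x)).getOrElse(JOptional.empty[T]())
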
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
index 1d2244eaf22b3..6ab1956bed900 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
@@ -57,7 +57,8 @@ private[streaming] object MapWithStateRDDRecord {
val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
if (wrappedState.isRemoved) {
newStateMap.remove(key)
- } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
+ } else if (wrappedState.isUpdated
+ || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
}
mappedData ++= returned
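
The extra `wrappedState.exists` guard is the actual fix for SPARK-13195: with only `timeoutThresholdTime.isDefined`, a key whose mapping function never set any state still reached `wrappedState.get()`, which throws. A compact restatement of the new save condition, using the same flags as the code above:

    // Write the state back only if it was updated, or if it already exists and a
    // timeout threshold is configured.
    def shouldSaveState(
        isRemoved: Boolean,
        isUpdated: Boolean,
        exists: Boolean,
        timeoutDefined: Boolean): Boolean =
      !isRemoved && (isUpdated || (exists && timeoutDefined))
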
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index a5a01e77639c4..a3ad5eaa40edc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler
import scala.util.{Failure, Success, Try}
import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time}
import org.apache.spark.streaming.util.RecurringTimer
import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils}
@@ -243,6 +244,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
// Example: BlockRDDs are created in this thread, and it needs to access BlockManager
// Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed.
SparkEnv.set(ssc.env)
+
+ // Checkpoint all RDDs marked for checkpointing to ensure their lineages are
+ // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
+ ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
Try {
jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
graph.generateJobs(time) // generate jobs using allocated block
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 9535c8e5b768a..3fed3d88354c7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -23,10 +23,10 @@ import scala.collection.JavaConverters._
import scala.util.Failure
import org.apache.spark.Logging
-import org.apache.spark.rdd.PairRDDFunctions
+import org.apache.spark.rdd.{PairRDDFunctions, RDD}
import org.apache.spark.streaming._
import org.apache.spark.streaming.ui.UIUtils
-import org.apache.spark.util.{EventLoop, ThreadUtils, Utils}
+import org.apache.spark.util.{EventLoop, ThreadUtils}
private[scheduler] sealed trait JobSchedulerEvent
@@ -210,6 +210,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
s"""Streaming job from $batchLinkText""")
ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString)
ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString)
+ // Checkpoint all RDDs marked for checkpointing to ensure their lineages are
+ // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
+ ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
// We need to assign `eventLoop` to a temp variable. Otherwise, because
// `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
index 7635f79a3d2d1..81de07f933f8a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
@@ -37,10 +37,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
private def columns: Seq[Node] = {
Output Op Id
Description
-
Duration
+
Output Op Duration
Status
Job Id
-
Duration
+
Job Duration
Stages: Succeeded/Total
Tasks (for all stages): Succeeded/Total
Error
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 4a6b91fbc745e..786703eb9a84e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -821,6 +821,75 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
checkpointWriter.stop()
}
+ test("SPARK-6847: stack overflow when updateStateByKey is followed by a checkpointed dstream") {
+ // In this test, there are two updateStateByKey operators. The RDD DAG is as follows:
+ //
+ // batch 1 batch 2 batch 3 ...
+ //
+ // 1) input rdd input rdd input rdd
+ // | | |
+ // v v v
+ // 2) cogroup rdd ---> cogroup rdd ---> cogroup rdd ...
+ // | / | / |
+ // v / v / v
+ // 3) map rdd --- map rdd --- map rdd ...
+ // | | |
+ // v v v
+ // 4) cogroup rdd ---> cogroup rdd ---> cogroup rdd ...
+ // | / | / |
+ // v / v / v
+ // 5) map rdd --- map rdd --- map rdd ...
+ //
+ // Every batch depends on its previous batch, so "updateStateByKey" needs to checkpoint its
+ // state RDD to break the RDD chain. However, before SPARK-6847, when the state RDD (layer 5)
+ // of the second "updateStateByKey" was checkpointed, it did not checkpoint the state RDD
+ // (layer 3) of the first "updateStateByKey", even though "updateStateByKey" had already marked
+ // that state RDD (layer 3) for checkpointing. Hence, the connections between layer 2 and
+ // layer 3 were never broken and the RDD chain grew without bound, eventually causing a
+ // StackOverflowError.
+ //
+ // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarked" to force checkpointing
+ // of all marked RDDs in the DAG, which resolves this issue by breaking the connections between
+ // layer 2 and layer 3 in the example above.
+ ssc = new StreamingContext(master, framework, batchDuration)
+ val batchCounter = new BatchCounter(ssc)
+ ssc.checkpoint(checkpointDir)
+ val inputDStream = new CheckpointInputDStream(ssc)
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some(values.sum + state.getOrElse(0))
+ }
+ @volatile var shouldCheckpointAllMarkedRDDs = false
+ @volatile var rddsCheckpointed = false
+ inputDStream.map(i => (i, i))
+ .updateStateByKey(updateFunc).checkpoint(batchDuration)
+ .updateStateByKey(updateFunc).checkpoint(batchDuration)
+ .foreachRDD { rdd =>
+ /**
+ * Find all RDDs that are marked for checkpointing in the specified RDD and its ancestors.
+ */
+ def findAllMarkedRDDs(rdd: RDD[_]): List[RDD[_]] = {
+ val markedRDDs = rdd.dependencies.flatMap(dep => findAllMarkedRDDs(dep.rdd)).toList
+ if (rdd.checkpointData.isDefined) {
+ rdd :: markedRDDs
+ } else {
+ markedRDDs
+ }
+ }
+
+ shouldCheckpointAllMarkedRDDs =
+ Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).
+ map(_.toBoolean).getOrElse(false)
+
+ val stateRDDs = findAllMarkedRDDs(rdd)
+ rdd.count()
+ // Check the two state RDDs are both checkpointed
+ rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed)
+ }
+ ssc.start()
+ batchCounter.waitUntilBatchesCompleted(1, 10000)
+ assert(shouldCheckpointAllMarkedRDDs === true)
+ assert(rddsCheckpointed === true)
+ }
+
/**
* Advances the manual clock on the streaming scheduler by given number of batches.
* It also waits for the expected amount of time for each batch.
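
A non-streaming way to picture what this test asserts, assuming the snippet runs inside Spark's own sources (the `CHECKPOINT_ALL_MARKED_ANCESTORS` constant is `private[spark]`) and that a checkpoint directory has been set on `sc`: two chained RDDs are both marked for checkpointing, but only the downstream one is materialized. With the property enabled, the upstream RDD is expected to be checkpointed as well, which is exactly what keeps the streaming lineage from growing.

    import org.apache.spark.rdd.RDD

    sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
    val upstream = sc.parallelize(1 to 100).map(i => (i, i))
    upstream.checkpoint()
    val downstream = upstream.mapValues(_ + 1)
    downstream.checkpoint()
    downstream.count()  // materialize: with the property set, both RDDs should be checkpointed
    assert(upstream.isCheckpointed && downstream.isCheckpointed)
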
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
index c4ecebcacf3c8..96dd4757be855 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
@@ -143,8 +143,9 @@ class UISeleniumSuite
summaryText should contain ("Total delay:")
findAll(cssSelector("""#batch-job-table th""")).map(_.text).toSeq should be {
- List("Output Op Id", "Description", "Duration", "Status", "Job Id", "Duration",
- "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", "Error")
+ List("Output Op Id", "Description", "Output Op Duration", "Status", "Job Id",
+ "Job Duration", "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total",
+ "Error")
}
// Check we have 2 output op ids
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
index 5b13fd6ad611a..e8c814ba7184b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
@@ -190,6 +190,11 @@ class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with B
timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
+ // If a state is not set but timeoutThreshold is defined, we should ignore this state.
+ // Previously it threw NoSuchElementException (SPARK-13195).
+ assertRecordUpdate(initStates = Seq(), data = Seq("noop"),
+ timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+ expectedStates = Nil, expectedTimingOutStates = Nil)
}
test("states generated by MapWithStateRDD") {
diff --git a/tags/pom.xml b/tags/pom.xml
index 9e4610dae7a65..3e8e6f6182875 100644
--- a/tags/pom.xml
+++ b/tags/pom.xml
@@ -21,13 +21,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xmlorg.apache.spark
- spark-test-tags_2.10
+ spark-test-tags_2.11jarSpark Project Test Tagshttp://spark.apache.org/
diff --git a/tools/pom.xml b/tools/pom.xml
index 30cbb6a5a59c7..b3a5ae2771241 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -19,13 +19,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xmlorg.apache.spark
- spark-tools_2.10
+ spark-tools_2.11tools
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 21fef3415adce..75fea556eeae1 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -21,13 +21,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xmlorg.apache.spark
- spark-unsafe_2.10
+ spark-unsafe_2.11jarSpark Project Unsafehttp://spark.apache.org/
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index b29bf6a464b30..18761bfd222a2 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -27,10 +27,14 @@ public final class Platform {
public static final int BYTE_ARRAY_OFFSET;
+ public static final int SHORT_ARRAY_OFFSET;
+
public static final int INT_ARRAY_OFFSET;
public static final int LONG_ARRAY_OFFSET;
+ public static final int FLOAT_ARRAY_OFFSET;
+
public static final int DOUBLE_ARRAY_OFFSET;
public static int getInt(Object object, long offset) {
@@ -168,13 +172,17 @@ public static void throwException(Throwable t) {
if (_UNSAFE != null) {
BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class);
+ SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class);
INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class);
+ FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class);
DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
} else {
BYTE_ARRAY_OFFSET = 0;
+ SHORT_ARRAY_OFFSET = 0;
INT_ARRAY_OFFSET = 0;
LONG_ARRAY_OFFSET = 0;
+ FLOAT_ARRAY_OFFSET = 0;
DOUBLE_ARRAY_OFFSET = 0;
}
}
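
The new `SHORT_ARRAY_OFFSET` / `FLOAT_ARRAY_OFFSET` constants follow the existing pattern: base offset of the on-heap array plus element index times element size. A small sketch of how such an offset is typically used, assuming `Platform` exposes a `getFloat(Object, long)` accessor analogous to the `getInt` shown above:

    import org.apache.spark.unsafe.Platform

    // Read the i-th element of a float[] through Unsafe-style addressing:
    // base offset of the array object plus 4 bytes per preceding element.
    def readFloat(arr: Array[Float], i: Int): Float =
      Platform.getFloat(arr, Platform.FLOAT_ARRAY_OFFSET + 4L * i)
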
diff --git a/yarn/pom.xml b/yarn/pom.xml
index a8c122fd40a1f..328bb6678db99 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -19,13 +19,13 @@
4.0.0org.apache.spark
- spark-parent_2.10
+ spark-parent_2.112.0.0-SNAPSHOT../pom.xmlorg.apache.spark
- spark-yarn_2.10
+ spark-yarn_2.11jarSpark Project YARN