diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index 89eec7d4b7f61..c99a61f63ea2b 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -10,3 +10,4 @@ log4j.logger.org.eclipse.jetty=WARN
log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
+log4j.logger.org.apache.hadoop.yarn.util.RackResolver=WARN
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index f02b035a980b1..a1f7133f897ee 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -121,6 +121,14 @@ pre {
border: none;
}
+.description-input {
+ overflow: hidden;
+ text-overflow: ellipsis;
+ width: 100%;
+ white-space: nowrap;
+ display: block;
+}
+
.stacktrace-details {
max-height: 300px;
overflow-y: auto;
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index a0ee2a7cbb2a2..b28da192c1c0d 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -158,7 +158,7 @@ private[spark] class ExecutorAllocationManager(
"shuffle service. You may enable this through spark.shuffle.service.enabled.")
}
if (tasksPerExecutor == 0) {
- throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores")
+ throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.")
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index a0ce107f43b16..f9d4aa4240e9d 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import scala.collection.JavaConverters._
+import scala.collection.concurrent.TrieMap
import scala.collection.mutable.{HashMap, LinkedHashSet}
import org.apache.spark.serializer.KryoSerializer
@@ -46,7 +47,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Create a SparkConf that loads defaults from system properties and the classpath */
def this() = this(true)
- private[spark] val settings = new HashMap[String, String]()
+ private[spark] val settings = new TrieMap[String, String]()
if (loadDefaults) {
// Load any spark.* system properties
@@ -177,7 +178,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
}
/** Get all parameters as a list of pairs */
- def getAll: Array[(String, String)] = settings.clone().toArray
+ def getAll: Array[(String, String)] = settings.toArray
/** Get a parameter as an integer, falling back to a default if not set */
def getInt(key: String, defaultValue: Int): Int = {
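Replacing the mutable HashMap with a TrieMap makes concurrent reads and writes on the conf safe, which is why `getAll` can drop the defensive `clone()`: TrieMap iteration works on a constant-time snapshot. A minimal sketch of the property this relies on (standalone example, not Spark code; names are illustrative):

  import scala.collection.concurrent.TrieMap

  object SnapshotIterationDemo {
    def main(args: Array[String]): Unit = {
      val settings = new TrieMap[String, String]()
      settings.put("spark.app.name", "demo")
      val writer = new Thread {
        override def run(): Unit = {
          for (i <- 1 to 10000) settings.put(s"spark.key.$i", i.toString)
        }
      }
      writer.start()
      // Safe while the writer mutates the map: the iterator sees a snapshot,
      // so this cannot throw or observe a half-updated structure.
      val snapshot = settings.toArray
      writer.join()
      println(s"snapshot size: ${snapshot.length}, final size: ${settings.size}")
    }
  }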
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index a1dfb01062591..33a7aae5d3fcd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -168,7 +168,7 @@ private[spark] class TaskSchedulerImpl(
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
- "and have sufficient memory")
+ "and have sufficient resources")
} else {
this.cancel()
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 662a7b91248aa..fa8a337ad63a8 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -92,7 +92,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade
}
override def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
+ new JavaDeserializationStream(s, defaultClassLoader)
}
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index b4677447c8872..fc1844600f1cb 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -22,20 +22,23 @@ import scala.util.Random
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.SchedulingMode
+// scalastyle:off
/**
* Continuously generates jobs that expose various features of the WebUI (internal testing tool).
*
- * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]
+ * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] [#job set (4 jobs per set)]
*/
+// scalastyle:on
private[spark] object UIWorkloadGenerator {
val NUM_PARTITIONS = 100
val INTER_JOB_WAIT_MS = 5000
def main(args: Array[String]) {
- if (args.length < 2) {
+ if (args.length < 3) {
println(
- "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
+ "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " +
+ "[master] [FIFO|FAIR] [#job set (4 jobs per set)]")
System.exit(1)
}
@@ -45,6 +48,7 @@ private[spark] object UIWorkloadGenerator {
if (schedulingMode == SchedulingMode.FAIR) {
conf.set("spark.scheduler.mode", "FAIR")
}
+ val nJobSet = args(2).toInt
val sc = new SparkContext(conf)
def setProperties(s: String) = {
@@ -84,7 +88,7 @@ private[spark] object UIWorkloadGenerator {
("Job with delays", baseData.map(x => Thread.sleep(100)).count)
)
- while (true) {
+ (1 to nJobSet).foreach { _ =>
for ((desc, job) <- jobs) {
new Thread {
override def run() {
@@ -101,5 +105,6 @@ private[spark] object UIWorkloadGenerator {
Thread.sleep(INTER_JOB_WAIT_MS)
}
}
+ sc.stop()
}
}
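With the now-mandatory third argument, the generator runs a bounded number of job sets and then shuts the context down instead of looping forever. A sample invocation (argument values are illustrative) would be:

  ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator local[4] FAIR 2

which submits 2 sets of the 4 predefined jobs, waits INTER_JOB_WAIT_MS between jobs, and finally calls sc.stop().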
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 81212708ba524..045c69da06feb 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -64,7 +64,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
{job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")}
- {lastStageDescription}

+ <span class="description-input" title={lastStageDescription}>{lastStageDescription}</span>
{lastStageName}
|
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
index 1da7a988203db..479f967fb1541 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
@@ -59,54 +59,86 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") {
val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable])
val poolTable = new PoolTable(pools, parent)
+ val shouldShowActiveStages = activeStages.nonEmpty
+ val shouldShowPendingStages = pendingStages.nonEmpty
+ val shouldShowCompletedStages = completedStages.nonEmpty
+ val shouldShowFailedStages = failedStages.nonEmpty
+
val summary: NodeSeq =
- {if (sc.isDefined) {
- // Total duration is not meaningful unless the UI is live
- <li>
- <strong>Total Duration: </strong>
- {UIUtils.formatDuration(now - sc.get.startTime)}
- </li>
- }}
+ {
+ if (sc.isDefined) {
+ // Total duration is not meaningful unless the UI is live
+ <li>
+ <strong>Total Duration: </strong>
+ {UIUtils.formatDuration(now - sc.get.startTime)}
+ </li>
+ }
+ }
<li>
<strong>Scheduling Mode: </strong>
{listener.schedulingMode.map(_.toString).getOrElse("Unknown")}
</li>
- <li>
- <a href="#active"><strong>Active Stages:</strong></a>
- {activeStages.size}
- </li>
- <li>
- <a href="#pending"><strong>Pending Stages:</strong></a>
- {pendingStages.size}
- </li>
- <li>
- <a href="#completed"><strong>Completed Stages:</strong></a>
- {numCompletedStages}
- </li>
- <li>
- <a href="#failed"><strong>Failed Stages:</strong></a>
- {numFailedStages}
- </li>
+ {
+ if (shouldShowActiveStages) {
+ <li>
+ <a href="#active"><strong>Active Stages:</strong></a>
+ {activeStages.size}
+ </li>
+ }
+ }
+ {
+ if (shouldShowPendingStages) {
+ <li>
+ <a href="#pending"><strong>Pending Stages:</strong></a>
+ {pendingStages.size}
+ </li>
+ }
+ }
+ {
+ if (shouldShowCompletedStages) {
+ <li>
+ <a href="#completed"><strong>Completed Stages:</strong></a>
+ {numCompletedStages}
+ </li>
+ }
+ }
+ {
+ if (shouldShowFailedStages) {
+ <li>
+ <a href="#failed"><strong>Failed Stages:</strong></a>
+ {numFailedStages}
+ </li>
+ }
+ }
- val content = summary ++
- {if (sc.isDefined && isFairScheduler) {
- <h4>{pools.size} Fair Scheduler Pools</h4> ++ poolTable.toNodeSeq
- } else {
- Seq[Node]()
- }} ++
- <h4 id="active">Active Stages ({activeStages.size})</h4> ++
- activeStagesTable.toNodeSeq ++
- <h4 id="pending">Pending Stages ({pendingStages.size})</h4> ++
- pendingStagesTable.toNodeSeq ++
- <h4 id="completed">Completed Stages ({numCompletedStages})</h4> ++
- completedStagesTable.toNodeSeq ++
- <h4 id="failed">Failed Stages ({numFailedStages})</h4> ++
+ var content = summary ++
+ {
+ if (sc.isDefined && isFairScheduler) {
+ <h4>{pools.size} Fair Scheduler Pools</h4> ++ poolTable.toNodeSeq
+ } else {
+ Seq[Node]()
+ }
+ }
+ if (shouldShowActiveStages) {
+ content ++= <h4 id="active">Active Stages ({activeStages.size})</h4> ++
+ activeStagesTable.toNodeSeq
+ }
+ if (shouldShowPendingStages) {
+ content ++= <h4 id="pending">Pending Stages ({pendingStages.size})</h4> ++
+ pendingStagesTable.toNodeSeq
+ }
+ if (shouldShowCompletedStages) {
+ content ++= <h4 id="completed">Completed Stages ({numCompletedStages})</h4> ++
+ completedStagesTable.toNodeSeq
+ }
+ if (shouldShowFailedStages) {
+ content ++= <h4 id="failed">Failed Stages ({numFailedStages})</h4> ++
failedStagesTable.toNodeSeq
-
+ }
UIUtils.headerSparkPage("Spark Stages (for all jobs)", content, parent)
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index e7d6244dcd679..703d43f9c640d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -112,9 +112,8 @@ private[ui] class StageTableBase(
stageData <- listener.stageIdToData.get((s.stageId, s.attemptId))
desc <- stageData.description
} yield {
- {desc}
+ <div class="description-input" title={desc}>{desc}</div>
}
-
{stageDesc.getOrElse("")} {killLink} {nameLink} {details}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
index 948c350953e27..de58be38c7bfb 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
@@ -54,7 +54,7 @@ object DenseGmmEM {
for (i <- 0 until clusters.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
- (clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
+ (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}
println("Cluster labels (first <= 100):")
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
index 8a13c74221546..2d6a825b61726 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
@@ -133,6 +133,12 @@ object GraphGenerators {
// This ensures that the 4 quadrants are the same size at all recursion levels
val numVertices = math.round(
math.pow(2.0, math.ceil(math.log(requestedNumVertices) / math.log(2.0)))).toInt
+ val numEdgesUpperBound =
+ math.pow(2.0, 2 * ((math.log(numVertices) / math.log(2.0)) - 1)).toInt
+ if (numEdgesUpperBound < numEdges) {
+ throw new IllegalArgumentException(
+ s"numEdges must be <= $numEdgesUpperBound but was $numEdges")
+ }
var edges: Set[Edge[Int]] = Set()
while (edges.size < numEdges) {
if (edges.size % 100 == 0) {
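The new guard caps numEdges at 2^(2(n-1)) where numVertices = 2^n, i.e. (numVertices/2)^2. Worked example: requesting 4 vertices gives n = 2, so the bound is 2^(2*(2-1)) = 4; asking for 4 edges succeeds while asking for 8 throws, which is exactly what the GraphGeneratorsSuite test added below asserts. Without the guard, an impossible request would leave the `while (edges.size < numEdges)` loop spinning forever.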
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
index 3abefbe52fa8a..8d9c8ddccbb3c 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
@@ -110,4 +110,14 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext {
}
}
+ test("SPARK-5064 GraphGenerators.rmatGraph numEdges upper bound") {
+ withSpark { sc =>
+ val g1 = GraphGenerators.rmatGraph(sc, 4, 4)
+ assert(g1.edges.count() === 4)
+ intercept[IllegalArgumentException] {
+ val g2 = GraphGenerators.rmatGraph(sc, 4, 8)
+ }
+ }
+ }
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 555da8c7e7ab3..430d763ef7ca7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable {
k: Int,
maxIterations: Int,
runs: Int,
- initializationMode: String): KMeansModel = {
+ initializationMode: String,
+ seed: java.lang.Long): KMeansModel = {
val kMeansAlg = new KMeans()
.setK(k)
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
+
+ if (seed != null) kMeansAlg.setSeed(seed)
+
try {
kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
} finally {
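Declaring seed as java.lang.Long rather than a Scala Long appears deliberate: a boxed Long is nullable, so a PySpark caller that omits the seed sends None, which arrives through Py4J as null and leaves KMeans on its random default; only an explicit seed triggers setSeed.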
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
index d8e134619411b..899fe5e9e9cf2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -134,9 +134,7 @@ class GaussianMixtureEM private (
// diagonal covariance matrices using component variances
// derived from the samples
val (weights, gaussians) = initialModel match {
- case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
- new MultivariateGaussian(mu, sigma)
- })
+ case Some(gmm) => (gmm.weights, gmm.gaussians)
case None => {
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
@@ -176,10 +174,7 @@ class GaussianMixtureEM private (
iter += 1
}
- // Need to convert the breeze matrices to MLlib matrices
- val means = Array.tabulate(k) { i => gaussians(i).mu }
- val sigmas = Array.tabulate(k) { i => gaussians(i).sigma }
- new GaussianMixtureModel(weights, means, sigmas)
+ new GaussianMixtureModel(weights, gaussians)
}
/** Average of dense breeze vectors */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 416cad080c408..1a2178ee7f711 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseVector => BreezeVector}
import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.{Matrix, Vector}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
@@ -36,12 +36,13 @@ import org.apache.spark.mllib.util.MLUtils
* covariance matrix for Gaussian i
*/
class GaussianMixtureModel(
- val weight: Array[Double],
- val mu: Array[Vector],
- val sigma: Array[Matrix]) extends Serializable {
+ val weights: Array[Double],
+ val gaussians: Array[MultivariateGaussian]) extends Serializable {
+
+ require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
/** Number of gaussians in mixture */
- def k: Int = weight.length
+ def k: Int = weights.length
/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
@@ -55,14 +56,10 @@ class GaussianMixtureModel(
*/
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
- val dists = sc.broadcast {
- (0 until k).map { i =>
- new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
- }.toArray
- }
- val weights = sc.broadcast(weight)
+ val bcDists = sc.broadcast(gaussians)
+ val bcWeights = sc.broadcast(weights)
points.map { x =>
- computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
+ computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
}
}
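With the model now carrying an Array[MultivariateGaussian] instead of parallel mu/sigma arrays, construction and inspection both go through the gaussians field. A minimal construction sketch, mirroring the shape used in the updated test suite below:

  import org.apache.spark.mllib.clustering.GaussianMixtureModel
  import org.apache.spark.mllib.linalg.{Matrices, Vectors}
  import org.apache.spark.mllib.stat.distribution.MultivariateGaussian

  val gmm = new GaussianMixtureModel(
    Array(0.5, 0.5),
    Array(
      new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
      new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))))

  // Component parameters are reached through the distribution objects now:
  val mu0 = gmm.gaussians(0).mu          // was gmm.mu(0)
  val sigma0 = gmm.gaussians(0).sigma    // was gmm.sigma(0)
  assert(gmm.k == 2)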
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 54c301d3e9e14..11633e8242313 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
-import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
/**
@@ -43,13 +43,14 @@ class KMeans private (
private var runs: Int,
private var initializationMode: String,
private var initializationSteps: Int,
- private var epsilon: Double) extends Serializable with Logging {
+ private var epsilon: Double,
+ private var seed: Long) extends Serializable with Logging {
/**
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
- * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}.
+ * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
*/
- def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
+ def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
/** Set the number of clusters to create (k). Default: 2. */
def setK(k: Int): this.type = {
@@ -112,6 +113,12 @@ class KMeans private (
this
}
+ /** Set the random seed for cluster initialization. */
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
@@ -255,7 +262,7 @@ class KMeans private (
private def initRandom(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
// Sample all the cluster centers in one pass to avoid repeated scans
- val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
+ val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
}.toArray)
@@ -272,45 +279,81 @@ class KMeans private (
*/
private def initKMeansParallel(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
- // Initialize each run's center to a random point
- val seed = new XORShiftRandom().nextInt()
+ // Initialize empty centers and point costs.
+ val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
+ var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
+
+ // Initialize each run's first center to a random point.
+ val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
- val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+ val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+
+ /** Merges the new centers into the current centers. */
+ def mergeNewCenters(): Unit = {
+ var r = 0
+ while (r < runs) {
+ centers(r) ++= newCenters(r)
+ newCenters(r).clear()
+ r += 1
+ }
+ }
// On each step, sample 2 * k points on average for each run with probability proportional
- // to their squared distance from that run's current centers
+ // to their squared distance from that run's centers. Note that only distances between points
+ // and new centers are computed in each iteration.
var step = 0
while (step < initializationSteps) {
- val bcCenters = data.context.broadcast(centers)
- val sumCosts = data.flatMap { point =>
- (0 until runs).map { r =>
- (r, KMeans.pointCost(bcCenters.value(r), point))
- }
- }.reduceByKey(_ + _).collectAsMap()
- val chosen = data.mapPartitionsWithIndex { (index, points) =>
+ val bcNewCenters = data.context.broadcast(newCenters)
+ val preCosts = costs
+ costs = data.zip(preCosts).map { case (point, cost) =>
+ Vectors.dense(
+ Array.tabulate(runs) { r =>
+ math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
+ })
+ }.cache()
+ val sumCosts = costs
+ .aggregate(Vectors.zeros(runs))(
+ seqOp = (s, v) => {
+ // s += v
+ axpy(1.0, v, s)
+ s
+ },
+ combOp = (s0, s1) => {
+ // s0 += s1
+ axpy(1.0, s1, s0)
+ s0
+ }
+ )
+ preCosts.unpersist(blocking = false)
+ val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
- points.flatMap { p =>
- (0 until runs).filter { r =>
- rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
- }.map((_, p))
+ pointsWithCosts.flatMap { case (p, c) =>
+ val rs = (0 until runs).filter { r =>
+ rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
+ }
+ if (rs.nonEmpty) Some((p, rs)) else None
}
}.collect()
- chosen.foreach { case (r, p) =>
- centers(r) += p.toDense
+ mergeNewCenters()
+ chosen.foreach { case (p, rs) =>
+ rs.foreach(newCenters(_) += p.toDense)
}
step += 1
}
+ mergeNewCenters()
+ costs.unpersist(blocking = false)
+
// Finally, we might have a set of more than k candidate centers for each run; weigh each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
val bcCenters = data.context.broadcast(centers)
val weightMap = data.flatMap { p =>
- (0 until runs).map { r =>
+ Iterator.tabulate(runs) { r =>
((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
- val finalCenters = (0 until runs).map { r =>
+ val finalCenters = (0 until runs).par.map { r =>
val myCenters = centers(r).toArray
val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
@@ -333,7 +376,32 @@ object KMeans {
/**
* Trains a k-means model using the given set of parameters.
*
- * @param data training points stored as `RDD[Array[Double]]`
+ * @param data training points stored as `RDD[Vector]`
+ * @param k number of clusters
+ * @param maxIterations max number of iterations
+ * @param runs number of parallel runs, defaults to 1. The best model is returned.
+ * @param initializationMode initialization model, either "random" or "k-means||" (default).
+ * @param seed random seed value for cluster initialization
+ */
+ def train(
+ data: RDD[Vector],
+ k: Int,
+ maxIterations: Int,
+ runs: Int,
+ initializationMode: String,
+ seed: Long): KMeansModel = {
+ new KMeans().setK(k)
+ .setMaxIterations(maxIterations)
+ .setRuns(runs)
+ .setInitializationMode(initializationMode)
+ .setSeed(seed)
+ .run(data)
+ }
+
+ /**
+ * Trains a k-means model using the given set of parameters.
+ *
+ * @param data training points stored as `RDD[Vector]`
* @param k number of clusters
* @param maxIterations max number of iterations
* @param runs number of parallel runs, defaults to 1. The best model is returned.
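The new seed parameter makes initialization reproducible across runs. A small sketch of the seeded overload (assumes a live SparkContext `sc`), following the same pattern as the KMeansSuite test added below:

  import org.apache.spark.mllib.clustering.KMeans
  import org.apache.spark.mllib.linalg.Vectors

  val rdd = sc.parallelize(List.tabulate(1000)(n => Vectors.dense(n, n)), 3)
  val modelA = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
    initializationMode = "k-means||", seed = 42)
  val modelB = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
    initializationMode = "k-means||", seed = 42)
  // With an identical seed the two models pick the same initial centers,
  // so the resulting cluster centers agree (up to floating-point noise).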
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index adbd8266ed6fa..7ee0224ad4662 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -50,13 +50,35 @@ sealed trait Vector extends Serializable {
override def equals(other: Any): Boolean = {
other match {
- case v: Vector =>
- util.Arrays.equals(this.toArray, v.toArray)
+ case v2: Vector => {
+ if (this.size != v2.size) return false
+ (this, v2) match {
+ case (s1: SparseVector, s2: SparseVector) =>
+ Vectors.equals(s1.indices, s1.values, s2.indices, s2.values)
+ case (s1: SparseVector, d1: DenseVector) =>
+ Vectors.equals(s1.indices, s1.values, 0 until d1.size, d1.values)
+ case (d1: DenseVector, s1: SparseVector) =>
+ Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values)
+ case (_, _) => util.Arrays.equals(this.toArray, v2.toArray)
+ }
+ }
case _ => false
}
}
- override def hashCode(): Int = util.Arrays.hashCode(this.toArray)
+ override def hashCode(): Int = {
+ var result: Int = size + 31
+ this.foreachActive { case (index, value) =>
+ // ignore explicit 0 for comparison between sparse and dense
+ if (value != 0) {
+ result = 31 * result + index
+ // refer to {@link java.util.Arrays.hashCode} for the hash algorithm
+ val bits = java.lang.Double.doubleToLongBits(value)
+ result = 31 * result + (bits ^ (bits >>> 32)).toInt
+ }
+ }
+ result
+ }
/**
* Converts the instance to a breeze vector.
@@ -392,6 +414,33 @@ object Vectors {
}
squaredDistance
}
+
+ /**
+ * Check equality between sparse/dense vectors
+ */
+ private[mllib] def equals(
+ v1Indices: IndexedSeq[Int],
+ v1Values: Array[Double],
+ v2Indices: IndexedSeq[Int],
+ v2Values: Array[Double]): Boolean = {
+ val v1Size = v1Values.size
+ val v2Size = v2Values.size
+ var k1 = 0
+ var k2 = 0
+ var allEqual = true
+ while (allEqual) {
+ while (k1 < v1Size && v1Values(k1) == 0) k1 += 1
+ while (k2 < v2Size && v2Values(k2) == 0) k2 += 1
+
+ if (k1 >= v1Size || k2 >= v2Size) {
+ return k1 >= v1Size && k2 >= v2Size // check end alignment
+ }
+ allEqual = v1Indices(k1) == v2Indices(k2) && v1Values(k1) == v2Values(k2)
+ k1 += 1
+ k2 += 1
+ }
+ allEqual
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
index 06d8915f3bfa1..b60559c853a50 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
@@ -69,6 +69,11 @@ class CoordinateMatrix(
nRows
}
+ /** Transposes this CoordinateMatrix. */
+ def transpose(): CoordinateMatrix = {
+ new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows())
+ }
+
/** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */
def toIndexedRowMatrix(): IndexedRowMatrix = {
val nl = numCols()
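A short usage sketch for the new transpose (assumes a live SparkContext `sc`; the entries are illustrative):

  import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}

  val entries = sc.parallelize(Seq(MatrixEntry(0, 1, 2.0), MatrixEntry(2, 0, 3.0)))
  val mat = new CoordinateMatrix(entries)   // 3 x 2, sized from the max indices
  val matT = mat.transpose()                // 2 x 3: each entry (i, j, v) becomes (j, i, v)
  assert(matT.numRows() == mat.numCols() && matT.numCols() == mat.numRows())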
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
index 181f507516485..c518271f04729 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -75,6 +75,23 @@ class IndexedRowMatrix(
new RowMatrix(rows.map(_.vector), 0L, nCols)
}
+ /**
+ * Converts this matrix to a
+ * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]].
+ */
+ def toCoordinateMatrix(): CoordinateMatrix = {
+ val entries = rows.flatMap { row =>
+ val rowIndex = row.index
+ row.vector match {
+ case SparseVector(size, indices, values) =>
+ Iterator.tabulate(indices.size)(i => MatrixEntry(rowIndex, indices(i), values(i)))
+ case DenseVector(values) =>
+ Iterator.tabulate(values.size)(i => MatrixEntry(rowIndex, i, values(i)))
+ }
+ }
+ new CoordinateMatrix(entries, numRows(), numCols())
+ }
+
/**
* Computes the singular value decomposition of this IndexedRowMatrix.
* Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'.
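A sketch of the new conversion (assumes a live SparkContext `sc`): sparse rows contribute only their explicitly stored entries, while dense rows emit one entry per cell, zeros included.

  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

  val rows = sc.parallelize(Seq(
    IndexedRow(0L, Vectors.dense(1.0, 0.0)),
    IndexedRow(3L, Vectors.sparse(2, Array(1), Array(4.0)))))
  val coordMat = new IndexedRowMatrix(rows).toCoordinateMatrix()
  // Dense row 0 yields (0,0,1.0) and (0,1,0.0); sparse row 3 yields only (3,1,4.0).
  assert(coordMat.entries.count() == 3)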
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index cf51d041c65a9..ed8e6a796f8c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -68,6 +68,15 @@ case class BoostingStrategy(
@Experimental
object BoostingStrategy {
+ /**
+ * Returns default configuration for the boosting algorithm
+ * @param algo Learning goal. Supported: "Classification" or "Regression"
+ * @return Configuration for boosting algorithm
+ */
+ def defaultParams(algo: String): BoostingStrategy = {
+ defaultParams(Algo.fromString(algo))
+ }
+
/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
@@ -75,15 +84,15 @@ object BoostingStrategy {
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
- def defaultParams(algo: String): BoostingStrategy = {
- val treeStrategy = Strategy.defaultStrategy(algo)
- treeStrategy.maxDepth = 3
+ def defaultParams(algo: Algo): BoostingStrategy = {
+ val treeStrategy = Strategy.defaultStrategy(algo)
+ treeStrategy.maxDepth = 3
algo match {
- case "Classification" =>
- treeStrategy.numClasses = 2
- new BoostingStrategy(treeStrategy, LogLoss)
- case "Regression" =>
- new BoostingStrategy(treeStrategy, SquaredError)
+ case Algo.Classification =>
+ treeStrategy.numClasses = 2
+ new BoostingStrategy(treeStrategy, LogLoss)
+ case Algo.Regression =>
+ new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d5cd89ab94e81..972959885f396 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -173,11 +173,19 @@ object Strategy {
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo "Classification" or "Regression"
*/
- def defaultStrategy(algo: String): Strategy = algo match {
- case "Classification" =>
+ def defaultStrategy(algo: String): Strategy = {
+ defaultStrategy(Algo.fromString(algo))
+ }
+
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo Algo.Classification or Algo.Regression
+ */
+ def defaultStrategy(algo: Algo): Strategy = algo match {
+ case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
- case "Regression" =>
+ case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
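A quick sketch of how the string and Algo entry points now relate (usage is illustrative):

  import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}

  // The string form is kept for convenience and simply forwards...
  val fromString = BoostingStrategy.defaultParams("Classification")
  // ...to the typed overload, which avoids stringly-typed dispatch:
  val fromAlgo = BoostingStrategy.defaultParams(Algo.Classification)
  val treeDefaults = Strategy.defaultStrategy(Algo.Regression)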
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
index 9da5495741a80..198997b5bb2b2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.{Vectors, Matrices}
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -39,9 +40,9 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
val seeds = Array(314589, 29032897, 50181, 494821, 4660)
seeds.foreach { seed =>
val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
- assert(gmm.weight(0) ~== Ew absTol 1E-5)
- assert(gmm.mu(0) ~== Emu absTol 1E-5)
- assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+ assert(gmm.weights(0) ~== Ew absTol 1E-5)
+ assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
+ assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
}
}
@@ -57,8 +58,10 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Array(0.5, 0.5),
- Array(Vectors.dense(-1.0), Vectors.dense(1.0)),
- Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0)))
+ Array(
+ new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
+ new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
+ )
)
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
@@ -70,11 +73,11 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
.setInitialModel(initialGmm)
.run(data)
- assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
- assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
- assert(gmm.mu(0) ~== Emu(0) absTol 1E-3)
- assert(gmm.mu(1) ~== Emu(1) absTol 1E-3)
- assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3)
- assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3)
+ assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
+ assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
+ assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
+ assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3)
+ assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
+ assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 9ebef8466c831..caee5917000aa 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
assert(model.clusterCenters.size === 3)
}
+ test("deterministic initialization") {
+ // Create a large-ish set of points for clustering
+ val points = List.tabulate(1000)(n => Vectors.dense(n, n))
+ val rdd = sc.parallelize(points, 3)
+
+ for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
+ // Create three deterministic models and compare cluster means
+ val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+ initializationMode = initMode, seed = 42)
+ val centers1 = model1.clusterCenters
+
+ val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+ initializationMode = initMode, seed = 42)
+ val centers2 = model2.clusterCenters
+
+ centers1.zip(centers2).foreach { case (c1, c2) =>
+ assert(c1 ~== c2 absTol 1E-14)
+ }
+ }
+ }
+
test("single cluster with big dataset") {
val smallData = Array(
Vectors.dense(1.0, 2.0, 6.0),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 85ac8ccebfc59..5def899cea117 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -89,6 +89,24 @@ class VectorsSuite extends FunSuite {
}
}
+ test("vectors equals with explicit 0") {
+ val dv1 = Vectors.dense(Array(0, 0.9, 0, 0.8, 0))
+ val sv1 = Vectors.sparse(5, Array(1, 3), Array(0.9, 0.8))
+ val sv2 = Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(0, 0.9, 0, 0.8, 0))
+
+ val vectors = Seq(dv1, sv1, sv2)
+ for (v <- vectors; u <- vectors) {
+ assert(v === u)
+ assert(v.## === u.##)
+ }
+
+ val another = Vectors.sparse(5, Array(0, 1, 3), Array(0, 0.9, 0.2))
+ for (v <- vectors) {
+ assert(v != another)
+ assert(v.## != another.##)
+ }
+ }
+
test("indexing dense vectors") {
val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0)
assert(vec(0) === 1.0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
index f8709751efce6..80bef814ce50d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
@@ -73,6 +73,11 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext {
assert(mat.toBreeze() === expected)
}
+ test("transpose") {
+ val transposed = mat.transpose()
+ assert(mat.toBreeze().t === transposed.toBreeze())
+ }
+
test("toIndexedRowMatrix") {
val indexedRowMatrix = mat.toIndexedRowMatrix()
val expected = BDM(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 741cd4997b853..b86c2ca5ff136 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -80,6 +80,14 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
assert(rowMat.rows.collect().toSeq === data.map(_.vector).toSeq)
}
+ test("toCoordinateMatrix") {
+ val idxRowMat = new IndexedRowMatrix(indexedRows)
+ val coordMat = idxRowMat.toCoordinateMatrix()
+ assert(coordMat.numRows() === m)
+ assert(coordMat.numCols() === n)
+ assert(coordMat.toBreeze() === idxRowMat.toBreeze())
+ }
+
test("multiply a local matrix") {
val A = new IndexedRowMatrix(indexedRows)
val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
diff --git a/pom.xml b/pom.xml
index f4466e56c2a53..b993391b15042 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1507,7 +1507,7 @@
mapr3
1.0.3-mapr-3.0.3
- 2.3.0-mapr-4.0.0-FCS
+ 2.4.1-mapr-1408
0.94.17-mapr-1405
3.4.5-mapr-1406
@@ -1516,8 +1516,8 @@
mapr4
- 2.3.0-mapr-4.0.0-FCS
- 2.3.0-mapr-4.0.0-FCS
+ 2.4.1-mapr-1408
+ 2.4.1-mapr-1408
0.94.17-mapr-1405-4.0.0-FCS
3.4.5-mapr-1406
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 95fef23ee4f39..127973b658190 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -86,6 +86,10 @@ object MimaExcludes {
// SPARK-5270
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.java.JavaRDDLike.isEmpty")
+ ) ++ Seq(
+ // SPARK-5297 Java FileStream does not work with custom key/value types
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream")
)
case v if v.startsWith("1.2") =>
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index e2492eef5bd6a..6b713aa39374e 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -78,10 +78,10 @@ def predict(self, x):
class KMeans(object):
@classmethod
- def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None):
"""Train a k-means clustering model."""
model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
- runs, initializationMode)
+ runs, initializationMode, seed)
centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 140c22b5fd4e8..f48e3d6dacb4b 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -140,7 +140,7 @@ class ListTests(PySparkTestCase):
as NumPy arrays.
"""
- def test_clustering(self):
+ def test_kmeans(self):
from pyspark.mllib.clustering import KMeans
data = [
[0, 1.1],
@@ -152,6 +152,21 @@ def test_clustering(self):
self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1]))
self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3]))
+ def test_kmeans_deterministic(self):
+ from pyspark.mllib.clustering import KMeans
+ X = range(0, 100, 10)
+ Y = range(0, 100, 10)
+ data = [[x, y] for x, y in zip(X, Y)]
+ clusters1 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||", seed=42)
+ clusters2 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||", seed=42)
+ centers1 = clusters1.centers
+ centers2 = clusters2.centers
+ for c1, c2 in zip(centers1, centers2):
+ # TODO: Allow small numeric difference.
+ self.assertTrue(array_equal(c1, c2))
+
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 208ec92987ac8..41bb4f012f2e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import scala.util.hashing.MurmurHash3
+
import org.apache.spark.sql.catalyst.expressions.GenericRow
@@ -32,7 +34,7 @@ object Row {
* }
* }}}
*/
- def unapplySeq(row: Row): Some[Seq[Any]] = Some(row)
+ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row.toSeq)
/**
* This method can be used to construct a [[Row]] with the given values.
@@ -43,6 +45,16 @@ object Row {
* This method can be used to construct a [[Row]] from a [[Seq]] of values.
*/
def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray)
+
+ def fromTuple(tuple: Product): Row = fromSeq(tuple.productIterator.toSeq)
+
+ /**
+ * Merge multiple rows into a single row, one after another.
+ */
+ def merge(rows: Row*): Row = {
+ // TODO: Improve the performance of this if used in performance critical part.
+ new GenericRow(rows.flatMap(_.toSeq).toArray)
+ }
}
@@ -103,7 +115,13 @@ object Row {
*
* @group row
*/
-trait Row extends Seq[Any] with Serializable {
+trait Row extends Serializable {
+ /** Number of elements in the Row. */
+ def size: Int = length
+
+ /** Number of elements in the Row. */
+ def length: Int
+
/**
* Returns the value at position i. If the value is null, null is returned. The following
* is a mapping between Spark SQL types and return types:
@@ -291,12 +309,61 @@ trait Row extends Seq[Any] with Serializable {
/** Returns true if there are any NULL values in this row. */
def anyNull: Boolean = {
- val l = length
+ val len = length
var i = 0
- while (i < l) {
+ while (i < len) {
if (isNullAt(i)) { return true }
i += 1
}
false
}
+
+ override def equals(that: Any): Boolean = that match {
+ case null => false
+ case that: Row =>
+ if (this.length != that.length) {
+ return false
+ }
+ var i = 0
+ val len = this.length
+ while (i < len) {
+ if (apply(i) != that.apply(i)) {
+ return false
+ }
+ i += 1
+ }
+ true
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ // Using Scala's Seq hash code implementation.
+ var n = 0
+ var h = MurmurHash3.seqSeed
+ val len = length
+ while (n < len) {
+ h = MurmurHash3.mix(h, apply(n).##)
+ n += 1
+ }
+ MurmurHash3.finalizeHash(h, n)
+ }
+
+ /* ---------------------- utility methods for Scala ---------------------- */
+
+ /**
* Returns a Scala Seq representing the row. Elements are placed in the same order in the Seq.
+ */
+ def toSeq: Seq[Any]
+
+ /** Displays all elements of this sequence in a string (without a separator). */
+ def mkString: String = toSeq.mkString
+
+ /** Displays all elements of this sequence in a string using a separator string. */
+ def mkString(sep: String): String = toSeq.mkString(sep)
+
+ /**
+ * Displays all elements of this traversable or iterator in a string using
+ * start, end, and separator strings.
+ */
+ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
}
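Because Row no longer extends Seq[Any], collection-style access now goes through toSeq, and the new helpers cover the common construction patterns. A small sketch using only the methods defined above:

  import org.apache.spark.sql.Row

  val a = Row(1, "x")
  val b = Row.fromTuple((2.0, true))
  val merged = Row.merge(a, b)                 // values: 1, x, 2.0, true
  assert(merged.length == 4 && !merged.anyNull)
  println(merged.mkString("[", ",", "]"))      // prints [1,x,2.0,true]
  val asSeq: Seq[Any] = merged.toSeq           // explicit conversion replaces the old Seq inheritance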
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
index 93d74adbcc957..366be00473d1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
@@ -25,15 +25,42 @@ import scala.util.parsing.input.CharArrayReader.EofCh
import org.apache.spark.sql.catalyst.plans.logical._
+private[sql] object KeywordNormalizer {
+ def apply(str: String) = str.toLowerCase()
+}
+
private[sql] abstract class AbstractSparkSQLParser
extends StandardTokenParsers with PackratParsers {
- def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match {
- case Success(plan, _) => plan
- case failureOrError => sys.error(failureOrError.toString)
+ def apply(input: String): LogicalPlan = {
+ // Initialize the Keywords.
+ lexical.initialize(reservedWords)
+ phrase(start)(new lexical.Scanner(input)) match {
+ case Success(plan, _) => plan
+ case failureOrError => sys.error(failureOrError.toString)
+ }
}
- protected case class Keyword(str: String)
+ protected case class Keyword(str: String) {
+ def normalize = KeywordNormalizer(str)
+ def parser: Parser[String] = normalize
+ }
+
+ protected implicit def asParser(k: Keyword): Parser[String] = k.parser
+
+ // By default, use reflection to find the reserved words defined in the subclass.
+ // NOTE: since the Keyword properties are defined by the subclass, we can't call
+ // this method during parent class instantiation, because the subclass instance
+ // hasn't been created yet.
+ protected lazy val reservedWords: Seq[String] =
+ this
+ .getClass
+ .getMethods
+ .filter(_.getReturnType == classOf[Keyword])
+ .map(_.invoke(this).asInstanceOf[Keyword].normalize)
+
+ // The reserved-word set starts out empty; it is populated later via initialize().
+ override val lexical = new SqlLexical
protected def start: Parser[LogicalPlan]
@@ -52,18 +79,27 @@ private[sql] abstract class AbstractSparkSQLParser
}
}
-class SqlLexical(val keywords: Seq[String]) extends StdLexical {
+class SqlLexical extends StdLexical {
case class FloatLit(chars: String) extends Token {
override def toString = chars
}
- reserved ++= keywords.flatMap(w => allCaseVersions(w))
+ /* This is a workaround to support setting the keywords lazily. */
+ def initialize(keywords: Seq[String]): Unit = {
+ reserved.clear()
+ reserved ++= keywords
+ }
delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
)
+ protected override def processIdent(name: String) = {
+ val token = KeywordNormalizer(name)
+ if (reserved contains token) Keyword(token) else Identifier(name)
+ }
+
override lazy val token: Parser[Token] =
( identChar ~ (identChar | digit).* ^^
{ case first ~ rest => processIdent((first :: rest).mkString) }
@@ -94,14 +130,5 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
| '-' ~ '-' ~ chrExcept(EofCh, '\n').*
| '/' ~ '*' ~ failure("unclosed comment")
).*
-
- /** Generate all variations of upper and lower case of a given string */
- def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
- if (s.isEmpty) {
- Stream(prefix)
- } else {
- allCaseVersions(s.tail, prefix + s.head.toLower) #:::
- allCaseVersions(s.tail, prefix + s.head.toUpper)
- }
- }
}
+
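The practical effect of normalizing keywords and identifiers to lower case is case-insensitive keyword matching without the old allCaseVersions expansion, which generated 2^n reserved variants for every n-letter keyword. A sketch of the observable behavior (assumes the concrete SqlParser below; each call should parse identically):

  val parser = new org.apache.spark.sql.catalyst.SqlParser
  // The lexer now lowercases tokens before checking the reserved set:
  parser("SELECT 1")
  parser("select 1")
  parser("SeLeCt 1")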
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d280db83b26f7..191d16fb10b5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -84,8 +84,9 @@ trait ScalaReflection {
}
def convertRowToScala(r: Row, schema: StructType): Row = {
+ // TODO: This is very slow!!!
new GenericRow(
- r.zip(schema.fields.map(_.dataType))
+ r.toSeq.zip(schema.fields.map(_.dataType))
.map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 0b36d8b9bfce5..eaadbe9fd5099 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -36,9 +36,8 @@ import org.apache.spark.sql.types._
* for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
*/
class SqlParser extends AbstractSparkSQLParser {
- protected implicit def asParser(k: Keyword): Parser[String] =
- lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
-
+ // Keyword is a convention shared with AbstractSparkSQLParser, which scans all of the
+ // `Keyword` properties of this class via reflection at runtime to construct the SqlLexical.
protected val ABS = Keyword("ABS")
protected val ALL = Keyword("ALL")
protected val AND = Keyword("AND")
@@ -107,16 +106,6 @@ class SqlParser extends AbstractSparkSQLParser {
protected val WHEN = Keyword("WHEN")
protected val WHERE = Keyword("WHERE")
- // Use reflection to find the reserved words defined in this class.
- protected val reservedWords =
- this
- .getClass
- .getMethods
- .filter(_.getReturnType == classOf[Keyword])
- .map(_.invoke(this).asInstanceOf[Keyword].str)
-
- override val lexical = new SqlLexical(reservedWords)
-
protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
exprs.zipWithIndex.map {
case (ne: NamedExpression, _) => ne
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 26c855878d202..417659eed5957 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -272,9 +272,6 @@ package object dsl {
def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) =
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
- def sfilter(dynamicUdf: (DynamicRow) => Boolean) =
- Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)
-
def sample(
fraction: Double,
withReplacement: Boolean = true,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1a2133bbbcec7..ece5ee73618cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -407,7 +407,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val casts = from.fields.zip(to.fields).map {
case (fromField, toField) => cast(fromField.dataType, toField.dataType)
}
- buildCast[Row](_, row => Row(row.zip(casts).map {
+ // TODO: This is very slow!
+ buildCast[Row](_, row => Row(row.toSeq.zip(casts).map {
case (v, cast) => if (v == null) null else cast(v)
}: _*))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index e7e81a21fdf03..db5d897ee569f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -105,45 +105,45 @@ class JoinedRow extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -154,8 +154,16 @@ class JoinedRow extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -197,45 +205,45 @@ class JoinedRow2 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -246,8 +254,16 @@ class JoinedRow2 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -283,45 +299,45 @@ class JoinedRow3 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -332,8 +348,16 @@ class JoinedRow3 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -369,45 +393,45 @@ class JoinedRow4 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -418,8 +442,16 @@ class JoinedRow4 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
@@ -455,45 +487,45 @@ class JoinedRow5 extends Row {
this
}
- def iterator = row1.iterator ++ row2.iterator
+ override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- def length = row1.length + row2.length
+ override def length = row1.length + row2.length
- def apply(i: Int) =
- if (i < row1.size) row1(i) else row2(i - row1.size)
+ override def apply(i: Int) =
+ if (i < row1.length) row1(i) else row2(i - row1.length)
- def isNullAt(i: Int) =
- if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+ override def isNullAt(i: Int) =
+ if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
- def getInt(i: Int): Int =
- if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+ override def getInt(i: Int): Int =
+ if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
- def getLong(i: Int): Long =
- if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+ override def getLong(i: Int): Long =
+ if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
- def getDouble(i: Int): Double =
- if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+ override def getDouble(i: Int): Double =
+ if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
- def getBoolean(i: Int): Boolean =
- if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+ override def getBoolean(i: Int): Boolean =
+ if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
- def getShort(i: Int): Short =
- if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+ override def getShort(i: Int): Short =
+ if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
- def getByte(i: Int): Byte =
- if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+ override def getByte(i: Int): Byte =
+ if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
- def getFloat(i: Int): Float =
- if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+ override def getFloat(i: Int): Float =
+ if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
- def getString(i: Int): String =
- if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getString(i: Int): String =
+ if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
override def getAs[T](i: Int): T =
- if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+ if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- def copy() = {
- val totalSize = row1.size + row2.size
+ override def copy() = {
+ val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
while(i < totalSize) {
@@ -504,7 +536,15 @@ class JoinedRow5 extends Row {
}
override def toString() = {
- val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
- s"[${row.mkString(",")}]"
+ // Make sure toString never throws NullPointerException.
+ if ((row1 eq null) && (row2 eq null)) {
+ "[ empty row ]"
+ } else if (row1 eq null) {
+ row2.mkString("[", ",", "]")
+ } else if (row2 eq null) {
+ row1.mkString("[", ",", "]")
+ } else {
+ mkString("[", ",", "]")
+ }
}
}
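Editor's note: all five JoinedRow variants now guard toString against half-initialized rows in the same way. A minimal standalone sketch of the pattern, using Seq[Any] as a stand-in for the catalyst Row type:

```scala
// Null-safe toString for a pair of possibly-null halves; `eq` avoids any
// overridden equals and never dereferences a null row.
class JoinedPair(row1: Seq[Any], row2: Seq[Any]) {
  def mkString(start: String, sep: String, end: String): String =
    (row1 ++ row2).mkString(start, sep, end)

  override def toString: String =
    if ((row1 eq null) && (row2 eq null)) "[ empty row ]"
    else if (row1 eq null) row2.mkString("[", ",", "]")
    else if (row2 eq null) row1.mkString("[", ",", "]")
    else mkString("[", ",", "]")
}

new JoinedPair(Seq(1, 2), null).toString  // "[1,2]" instead of an NPE
```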
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 37d9f0ed5c79e..7434165f654f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -209,6 +209,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def length: Int = values.length
+ override def toSeq: Seq[Any] = values.map(_.boxed).toSeq
+
override def setNullAt(i: Int): Unit = {
values(i).isNull = true
}
@@ -231,8 +233,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
}
- override def iterator: Iterator[Any] = values.map(_.boxed).iterator
-
override def setString(ordinal: Int, value: String) = update(ordinal, value)
override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
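Editor's note: with `iterator` removed from the Row contract, `toSeq` becomes the generic traversal point, and SpecificMutableRow must box its specialized slots to produce one. A rough model of that boxing, with a hypothetical single-type slot:

```scala
// Hypothetical specialized slot: stores an unboxed Int plus a null flag.
final class MutableInt(var value: Int, var isNull: Boolean = false) {
  def boxed: Any = if (isNull) null else value
}

val slots = Array(new MutableInt(1), new MutableInt(0, isNull = true))
val snapshot: Seq[Any] = slots.map(_.boxed).toSeq  // Seq(1, null)
```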
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
deleted file mode 100644
index 8328278544a1e..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.language.dynamics
-
-import org.apache.spark.sql.types.DataType
-
-/**
- * The data type representing [[DynamicRow]] values.
- */
-case object DynamicType extends DataType
-
-/**
- * Wrap a [[Row]] as a [[DynamicRow]].
- */
-case class WrapDynamic(children: Seq[Attribute]) extends Expression {
- type EvaluatedType = DynamicRow
-
- def nullable = false
-
- def dataType = DynamicType
-
- override def eval(input: Row): DynamicRow = input match {
- // Avoid copy for generic rows.
- case g: GenericRow => new DynamicRow(children, g.values)
- case otherRowType => new DynamicRow(children, otherRowType.toArray)
- }
-}
-
-/**
- * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language.
- * Since the type of the column is not known at compile time, all attributes are converted to
- * strings before being passed to the function.
- */
-class DynamicRow(val schema: Seq[Attribute], values: Array[Any])
- extends GenericRow(values) with Dynamic {
-
- def selectDynamic(attributeName: String): String = {
- val ordinal = schema.indexWhere(_.name == attributeName)
- values(ordinal).toString
- }
-}
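Editor's note: for context on what is being deleted here, `scala.Dynamic` routes unresolved member selections through `selectDynamic`, which is how DynamicRow emulated field access by name. A self-contained sketch of that mechanism:

```scala
import scala.language.dynamics

// Unknown members on a Dynamic subclass compile to selectDynamic calls,
// so r.firstName becomes r.selectDynamic("firstName").
class DynRecord(schema: Seq[String], values: Array[Any]) extends Dynamic {
  def selectDynamic(name: String): String =
    values(schema.indexOf(name)).toString
}

val r = new DynRecord(Seq("firstName", "lastName"), Array("Bob", "Smith"))
assert(r.firstName == "Bob" && r.lastName == "Smith")
```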
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index cc97cb4f50b69..69397a73a8880 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -77,14 +77,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
""".children : Seq[Tree]
}
- val iteratorFunction = {
- val allColumns = (0 until expressions.size).map { i =>
- val iLit = ru.Literal(Constant(i))
- q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
- }
- q"override def iterator = Iterator[Any](..$allColumns)"
- }
-
val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)"""
val applyFunction = {
val cases = (0 until expressions.size).map { i =>
@@ -191,20 +183,26 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
"""
+ val allColumns = (0 until expressions.size).map { i =>
+ val iLit = ru.Literal(Constant(i))
+ q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
+ }
+
val copyFunction =
- q"""
- override def copy() = new $genericRowType(this.toArray)
- """
+ q"override def copy() = new $genericRowType(Array[Any](..$allColumns))"
+
+ val toSeqFunction =
+ q"override def toSeq: Seq[Any] = Seq(..$allColumns)"
val classBody =
nullFunctions ++ (
lengthDef +:
- iteratorFunction +:
applyFunction +:
updateFunction +:
equalsFunction +:
hashCodeFunction +:
copyFunction +:
+ toSeqFunction +:
(tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions))
val code = q"""
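Editor's note: the refactor hoists `allColumns` so the same column-reading trees are spliced into both `copy()` and the new `toSeq`, keeping the two consistent by construction. A minimal scala-reflect sketch of the splicing idea (member names such as isNullAt, GenericRow, and c0..c2 stand in for the generated class's members):

```scala
import scala.reflect.runtime.universe._

val allColumns = (0 until 3).map { i =>
  val iLit = Literal(Constant(i))
  q"if (isNullAt($iLit)) null else ${TermName(s"c$i")}"
}

// ..$allColumns splices the same trees into both generated members.
val copyDef  = q"override def copy() = new GenericRow(Array[Any](..$allColumns))"
val toSeqDef = q"override def toSeq: Seq[Any] = Seq(..$allColumns)"
println(showCode(toSeqDef))
```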
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index c22b8426841da..8df150e2f855f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -44,7 +44,7 @@ trait MutableRow extends Row {
*/
object EmptyRow extends Row {
override def apply(i: Int): Any = throw new UnsupportedOperationException
- override def iterator = Iterator.empty
+ override def toSeq = Seq.empty
override def length = 0
override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException
override def getInt(i: Int): Int = throw new UnsupportedOperationException
@@ -70,7 +70,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
def this(size: Int) = this(new Array[Any](size))
- override def iterator = values.iterator
+ override def toSeq = values.toSeq
override def length = values.length
@@ -119,7 +119,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
// Custom hashCode function that matches the efficient code generated version.
- override def hashCode(): Int = {
+ override def hashCode: Int = {
var result: Int = 37
var i = 0
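Editor's note: dropping the `()` makes the signature match the generated version the comment refers to. That generated version uses a multiply-accumulate scheme; a sketch of its shape (constants beyond the visible 37 seed are illustrative, not Spark's exact per-type cases):

```scala
// Illustrative multiply-accumulate row hash: fold each column's hash into
// the running result with the same seed and multiplier.
def rowHash(values: Array[Any]): Int = {
  var result: Int = 37
  var i = 0
  while (i < values.length) {
    val update = if (values(i) == null) 0 else values(i).hashCode()
    result = 37 * result + update
    i += 1
  }
  result
}
```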
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 1483beacc9088..9628e93274a11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -238,16 +238,11 @@ case class Rollup(
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override lazy val statistics: Statistics =
- if (output.forall(_.dataType.isInstanceOf[NativeType])) {
- val limit = limitExpr.eval(null).asInstanceOf[Int]
- val sizeInBytes = (limit: Long) * output.map { a =>
- NativeType.defaultSizeOf(a.dataType.asInstanceOf[NativeType])
- }.sum
- Statistics(sizeInBytes = sizeInBytes)
- } else {
- Statistics(sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product)
- }
+ override lazy val statistics: Statistics = {
+ val limit = limitExpr.eval(null).asInstanceOf[Int]
+ val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+ Statistics(sizeInBytes = sizeInBytes)
+ }
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
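Editor's note: because every DataType now reports a defaultSize, Limit no longer needs the NativeType-only branch or the fallback product of child sizes. A worked example of the new estimate, assuming a (INT, LONG, STRING) schema:

```scala
// sizeInBytes = limit * sum of per-column default sizes
val limit = 1000L
val columnDefaults = Seq(4, 8, 4096)         // IntegerType, LongType, StringType
val sizeInBytes = limit * columnDefaults.sum // 1000 * 4108 = 4108000 bytes
```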
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index bcd74603d4013..9f30f40a173e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -215,6 +215,9 @@ abstract class DataType {
case _ => false
}
+ /** The default size of a value of this data type. */
+ def defaultSize: Int
+
def isPrimitive: Boolean = false
def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
@@ -235,33 +238,25 @@ abstract class DataType {
* @group dataType
*/
@DeveloperApi
-case object NullType extends DataType
+case object NullType extends DataType {
+ override def defaultSize: Int = 1
+}
-object NativeType {
+protected[sql] object NativeType {
val all = Seq(
IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
def unapply(dt: DataType): Boolean = all.contains(dt)
-
- val defaultSizeOf: Map[NativeType, Int] = Map(
- IntegerType -> 4,
- BooleanType -> 1,
- LongType -> 8,
- DoubleType -> 8,
- FloatType -> 4,
- ShortType -> 2,
- ByteType -> 1,
- StringType -> 4096)
}
-trait PrimitiveType extends DataType {
+protected[sql] trait PrimitiveType extends DataType {
override def isPrimitive = true
}
-object PrimitiveType {
+protected[sql] object PrimitiveType {
private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
@@ -276,7 +271,7 @@ object PrimitiveType {
}
}
-abstract class NativeType extends DataType {
+protected[sql] abstract class NativeType extends DataType {
private[sql] type JvmType
@transient private[sql] val tag: TypeTag[JvmType]
private[sql] val ordering: Ordering[JvmType]
@@ -300,6 +295,11 @@ case object StringType extends NativeType with PrimitiveType {
private[sql] type JvmType = String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the StringType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
}
@@ -324,6 +324,11 @@ case object BinaryType extends NativeType with PrimitiveType {
x.length - y.length
}
}
+
+ /**
+ * The default size of a value of the BinaryType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
}
@@ -339,6 +344,11 @@ case object BooleanType extends NativeType with PrimitiveType {
private[sql] type JvmType = Boolean
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the BooleanType is 1 byte.
+ */
+ override def defaultSize: Int = 1
}
@@ -359,6 +369,11 @@ case object TimestampType extends NativeType {
private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
}
+
+ /**
+ * The default size of a value of the TimestampType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
}
@@ -379,10 +394,15 @@ case object DateType extends NativeType {
private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Date, y: Date) = x.compareTo(y)
}
+
+ /**
+ * The default size of a value of the DateType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
}
-abstract class NumericType extends NativeType with PrimitiveType {
+protected[sql] abstract class NumericType extends NativeType with PrimitiveType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
// type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
@@ -392,13 +412,13 @@ abstract class NumericType extends NativeType with PrimitiveType {
}
-object NumericType {
+protected[sql] object NumericType {
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
}
/** Matcher for any expressions that evaluate to [[IntegralType]]s */
-object IntegralType {
+protected[sql] object IntegralType {
def unapply(a: Expression): Boolean = a match {
case e: Expression if e.dataType.isInstanceOf[IntegralType] => true
case _ => false
@@ -406,7 +426,7 @@ object IntegralType {
}
-sealed abstract class IntegralType extends NumericType {
+protected[sql] sealed abstract class IntegralType extends NumericType {
private[sql] val integral: Integral[JvmType]
}
@@ -425,6 +445,11 @@ case object LongType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Long]]
private[sql] val integral = implicitly[Integral[Long]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the LongType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
}
@@ -442,6 +467,11 @@ case object IntegerType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Int]]
private[sql] val integral = implicitly[Integral[Int]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the IntegerType is 4 bytes.
+ */
+ override def defaultSize: Int = 4
}
@@ -459,6 +489,11 @@ case object ShortType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Short]]
private[sql] val integral = implicitly[Integral[Short]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the ShortType is 2 bytes.
+ */
+ override def defaultSize: Int = 2
}
@@ -476,11 +511,16 @@ case object ByteType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Byte]]
private[sql] val integral = implicitly[Integral[Byte]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the ByteType is 1 byte.
+ */
+ override def defaultSize: Int = 1
}
/** Matcher for any expressions that evaluate to [[FractionalType]]s */
-object FractionalType {
+protected[sql] object FractionalType {
def unapply(a: Expression): Boolean = a match {
case e: Expression if e.dataType.isInstanceOf[FractionalType] => true
case _ => false
@@ -488,7 +528,7 @@ object FractionalType {
}
-sealed abstract class FractionalType extends NumericType {
+protected[sql] sealed abstract class FractionalType extends NumericType {
private[sql] val fractional: Fractional[JvmType]
private[sql] val asIntegral: Integral[JvmType]
}
@@ -530,6 +570,11 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
case None => "DecimalType()"
}
+
+ /**
+ * The default size of a value of the DecimalType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
}
@@ -580,6 +625,11 @@ case object DoubleType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Double]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = DoubleAsIfIntegral
+
+ /**
+ * The default size of a value of the DoubleType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
}
@@ -598,6 +648,11 @@ case object FloatType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Float]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = FloatAsIfIntegral
+
+ /**
+ * The default size of a value of the FloatType is 4 bytes.
+ */
+ override def defaultSize: Int = 4
}
@@ -636,6 +691,12 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
("type" -> typeName) ~
("elementType" -> elementType.jsonValue) ~
("containsNull" -> containsNull)
+
+ /**
+ * The default size of a value of the ArrayType is 100 * the default size of the element type.
+ * (We assume that there are 100 elements).
+ */
+ override def defaultSize: Int = 100 * elementType.defaultSize
}
@@ -805,6 +866,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
override def length: Int = fields.length
override def iterator: Iterator[StructField] = fields.iterator
+
+ /**
+ * The default size of a value of the StructType is the total default sizes of all field types.
+ */
+ override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
}
@@ -848,6 +914,13 @@ case class MapType(
("keyType" -> keyType.jsonValue) ~
("valueType" -> valueType.jsonValue) ~
("valueContainsNull" -> valueContainsNull)
+
+ /**
+ * The default size of a value of the MapType is
+ * 100 * (the default size of the key type + the default size of the value type).
+ * (We assume that there are 100 elements).
+ */
+ override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
}
@@ -896,4 +969,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
* Class object for the UserType
*/
def userClass: java.lang.Class[UserType]
+
+ /**
+ * The default size of a value of the UserDefinedType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
}
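Editor's note: the container types make defaultSize recursive: arrays and maps assume 100 entries, structs sum their fields. A compact standalone model of the rules added above:

```scala
sealed trait DT { def defaultSize: Int }
case object IntDT    extends DT { val defaultSize = 4 }
case object DoubleDT extends DT { val defaultSize = 8 }
case object StringDT extends DT { val defaultSize = 4096 }
case class ArrayDT(elem: DT) extends DT {
  def defaultSize = 100 * elem.defaultSize  // assume 100 elements
}
case class MapDT(key: DT, value: DT) extends DT {
  def defaultSize = 100 * (key.defaultSize + value.defaultSize)
}
case class StructDT(fields: Seq[DT]) extends DT {
  def defaultSize = fields.map(_.defaultSize).sum
}

ArrayDT(StringDT).defaultSize                // 409600
MapDT(IntDT, ArrayDT(DoubleDT)).defaultSize  // 100 * (4 + 800) = 80400
```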
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 6df5db4c80f34..5138942a55daa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -244,7 +244,7 @@ class ScalaReflectionSuite extends FunSuite {
test("convert PrimitiveData to catalyst") {
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
- val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
+ val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
val dataType = schemaFor[PrimitiveData].dataType
assert(convertToCatalyst(data, dataType) === convertedData)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
new file mode 100644
index 0000000000000..1a0a0e6154ad2
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.Command
+import org.scalatest.FunSuite
+
+private[sql] case class TestCommand(cmd: String) extends Command
+
+private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser {
+ protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST")
+
+ override protected lazy val start: Parser[LogicalPlan] = set
+
+ private lazy val set: Parser[LogicalPlan] =
+ EXECUTE ~> ident ^^ {
+ case fileName => TestCommand(fileName)
+ }
+}
+
+private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
+ protected val EXECUTE = Keyword("EXECUTE")
+
+ override protected lazy val start: Parser[LogicalPlan] = set
+
+ private lazy val set: Parser[LogicalPlan] =
+ EXECUTE ~> ident ^^ {
+ case fileName => TestCommand(fileName)
+ }
+}
+
+class SqlParserSuite extends FunSuite {
+
+ test("test long keyword") {
+ val parser = new SuperLongKeywordTestParser
+ assert(TestCommand("NotRealCommand") === parser("ThisIsASuperLongKeyWordTest NotRealCommand"))
+ }
+
+ test("test case insensitive") {
+ val parser = new CaseInsensitiveTestParser
+ assert(TestCommand("NotRealCommand") === parser("EXECUTE NotRealCommand"))
+ assert(TestCommand("NotRealCommand") === parser("execute NotRealCommand"))
+ assert(TestCommand("NotRealCommand") === parser("exEcute NotRealCommand"))
+ }
+}
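Editor's note: the case-insensitivity exercised by this suite comes from the lexer registering every casing of each keyword as an alternative token. A minimal sketch of that expansion (mirroring what SqlLexical.allCaseVersions does):

```scala
// Enumerate all 2^n casings of a keyword; each becomes one alternative in
// the keyword's parser, so "execute", "EXECUTE" and "exEcute" all match.
def allCaseVersions(s: String, prefix: String = ""): Stream[String] =
  if (s.isEmpty) Stream(prefix)
  else allCaseVersions(s.tail, prefix + s.head.toLower) ++
       allCaseVersions(s.tail, prefix + s.head.toUpper)

allCaseVersions("set").toList
// List(set, seT, sEt, sET, Set, SeT, SEt, SET)
```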
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 892195f46ea24..c147be9f6b1ae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -62,6 +62,7 @@ class DataTypeSuite extends FunSuite {
}
}
+ checkDataTypeJsonRepr(NullType)
checkDataTypeJsonRepr(BooleanType)
checkDataTypeJsonRepr(ByteType)
checkDataTypeJsonRepr(ShortType)
@@ -69,7 +70,9 @@ class DataTypeSuite extends FunSuite {
checkDataTypeJsonRepr(LongType)
checkDataTypeJsonRepr(FloatType)
checkDataTypeJsonRepr(DoubleType)
+ checkDataTypeJsonRepr(DecimalType(10, 5))
checkDataTypeJsonRepr(DecimalType.Unlimited)
+ checkDataTypeJsonRepr(DateType)
checkDataTypeJsonRepr(TimestampType)
checkDataTypeJsonRepr(StringType)
checkDataTypeJsonRepr(BinaryType)
@@ -77,12 +80,39 @@ class DataTypeSuite extends FunSuite {
checkDataTypeJsonRepr(ArrayType(StringType, false))
checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+
val metadata = new MetadataBuilder()
.putString("name", "age")
.build()
- checkDataTypeJsonRepr(
- StructType(Seq(
- StructField("a", IntegerType, nullable = true),
- StructField("b", ArrayType(DoubleType), nullable = false),
- StructField("c", DoubleType, nullable = false, metadata))))
+ val structType = StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", ArrayType(DoubleType), nullable = false),
+ StructField("c", DoubleType, nullable = false, metadata)))
+ checkDataTypeJsonRepr(structType)
+
+ def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = {
+ test(s"Check the default size of ${dataType}") {
+ assert(dataType.defaultSize === expectedDefaultSize)
+ }
+ }
+
+ checkDefaultSize(NullType, 1)
+ checkDefaultSize(BooleanType, 1)
+ checkDefaultSize(ByteType, 1)
+ checkDefaultSize(ShortType, 2)
+ checkDefaultSize(IntegerType, 4)
+ checkDefaultSize(LongType, 8)
+ checkDefaultSize(FloatType, 4)
+ checkDefaultSize(DoubleType, 8)
+ checkDefaultSize(DecimalType(10, 5), 4096)
+ checkDefaultSize(DecimalType.Unlimited, 4096)
+ checkDefaultSize(DateType, 8)
+ checkDefaultSize(TimestampType, 8)
+ checkDefaultSize(StringType, 4096)
+ checkDefaultSize(BinaryType, 4096)
+ checkDefaultSize(ArrayType(DoubleType, true), 800)
+ checkDefaultSize(ArrayType(StringType, false), 409600)
+ checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
+ checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
+ checkDefaultSize(structType, 812)
}
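Editor's note: the expected sizes for the composite types are just the recursive rules applied to the leaf sizes checked above:

```scala
100 * 8              // ArrayType(DoubleType)                       = 800
100 * 4096           // ArrayType(StringType)                       = 409600
100 * (4 + 4096)     // MapType(IntegerType, StringType)            = 410000
100 * (4 + 100 * 8)  // MapType(IntegerType, ArrayType(DoubleType)) = 80400
4 + 100 * 8 + 8      // structType: Int + Array[Double] + Double    = 812
```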
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index f23cb18c92d5d..0a22968cc7807 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -107,7 +107,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
protected[sql] def parseSql(sql: String): LogicalPlan = {
- ddlParser(sql).getOrElse(sqlParser(sql))
+ ddlParser(sql, false).getOrElse(sqlParser(sql))
}
protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index ae4d8ba90c5bd..d1e21dffeb8c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -330,25 +330,6 @@ class SchemaRDD(
sqlContext,
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan))
- /**
- * :: Experimental ::
- * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use
- * scala's Dynamic trait to emulate an ORM of in a dynamically typed language. Since the type of
- * the column is not known at compile time, all attributes are converted to strings before
- * being passed to the function.
- *
- * {{{
- * schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith")
- * }}}
- *
- * @group Query
- */
- @Experimental
- def where(dynamicUdf: (DynamicRow) => Boolean) =
- new SchemaRDD(
- sqlContext,
- Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan))
-
/**
* :: Experimental ::
* Returns a sampled version of the underlying dataset.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
index f10ee7b66feb7..f1a4053b79113 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql
+
import scala.util.parsing.combinator.RegexParsers
-import org.apache.spark.sql.catalyst.{SqlLexical, AbstractSparkSQLParser}
+import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand}
@@ -61,18 +62,6 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr
protected val TABLE = Keyword("TABLE")
protected val UNCACHE = Keyword("UNCACHE")
- protected implicit def asParser(k: Keyword): Parser[String] =
- lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
-
- private val reservedWords: Seq[String] =
- this
- .getClass
- .getMethods
- .filter(_.getReturnType == classOf[Keyword])
- .map(_.invoke(this).asInstanceOf[Keyword].str)
-
- override val lexical = new SqlLexical(reservedWords)
-
override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others
private lazy val cache: Parser[LogicalPlan] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 065fae3c83df1..11d5943fb427f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -21,7 +21,6 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
@@ -128,8 +127,7 @@ private[sql] case class InMemoryRelation(
rowCount += 1
}
- val stats = Row.fromSeq(
- columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
+ val stats = Row.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*)
batchStats += stats
CachedBatch(columnBuilders.map(_.build().array()), stats)
@@ -271,9 +269,10 @@ private[sql] case class InMemoryColumnarTableScan(
// Extract rows via column accessors
new Iterator[Row] {
+ private[this] val rowLen = nextRow.length
override def next() = {
var i = 0
- while (i < nextRow.length) {
+ while (i < rowLen) {
columnAccessors(i).extractTo(nextRow, i)
i += 1
}
@@ -297,7 +296,7 @@ private[sql] case class InMemoryColumnarTableScan(
cachedBatchIterator.filter { cachedBatch =>
if (!partitionFilter(cachedBatch.stats)) {
def statsString = relation.partitionStatistics.schema
- .zip(cachedBatch.stats)
+ .zip(cachedBatch.stats.toSeq)
.map { case (a, s) => s"${a.name}: $s" }
.mkString(", ")
logInfo(s"Skipping partition based on stats $statsString")
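Editor's note: two small hot-path changes in this scan: Row.merge builds the per-batch stats row without intermediate Seq concatenation, and the row width is hoisted out of the per-row extraction loop. A sketch of the hoisting, with illustrative names:

```scala
// Cache the loop-invariant row width when the iterator is built instead of
// re-reading it inside next() for every row.
class RowFillingIterator(nextRow: Array[Any], batchRows: Int,
    extractTo: (Array[Any], Int) => Unit) extends Iterator[Array[Any]] {
  private[this] val rowLen = nextRow.length
  private[this] var seen = 0
  override def hasNext = seen < batchRows
  override def next() = {
    var i = 0
    while (i < rowLen) { extractTo(nextRow, i); i += 1 }
    seen += 1
    nextRow
  }
}
```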
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 64673248394c6..68a5b1de7691b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -127,7 +127,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
while (from.hasRemaining) {
columnType.extract(from, value, 0)
- if (value.head == currentValue.head) {
+ if (value(0) == currentValue(0)) {
currentRun += 1
} else {
// Writes current run
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 46245cd5a1869..4d7e338e8ed13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -144,7 +144,7 @@ package object debug {
case (null, _) =>
case (row: Row, StructType(fields)) =>
- row.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) }
+ row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
case (s: Seq[_], ArrayType(elemType, _)) =>
s.foreach(typeCheck(_, elemType))
case (m: Map[_, _], MapType(keyType, valueType, _)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 7ed64aad10d4e..b85021acc9d4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -116,9 +116,9 @@ object EvaluatePython {
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null
- case (row: Seq[Any], struct: StructType) =>
+ case (row: Row, struct: StructType) =>
val fields = struct.fields.map(field => field.dataType)
- row.zip(fields).map {
+ row.toSeq.zip(fields).map {
case (obj, dataType) => toJava(obj, dataType)
}.toArray
@@ -143,7 +143,8 @@ object EvaluatePython {
* Convert Row into Java Array (for pickled into Python)
*/
def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
- row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
+ // TODO: this is slow!
+ row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
}
// Converts value to the type specified by the data type.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index db70a7eac72b9..9171939f7e8f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -458,16 +458,16 @@ private[sql] object JsonRDD extends Logging {
gen.writeEndArray()
case (MapType(kv,vv, _), v: Map[_,_]) =>
- gen.writeStartObject
+ gen.writeStartObject()
v.foreach { p =>
gen.writeFieldName(p._1.toString)
valWriter(vv,p._2)
}
- gen.writeEndObject
+ gen.writeEndObject()
- case (StructType(ty), v: Seq[_]) =>
+ case (StructType(ty), v: Row) =>
gen.writeStartObject()
- ty.zip(v).foreach {
+ ty.zip(v.toSeq).foreach {
case (_, null) =>
case (field, v) =>
gen.writeFieldName(field.name)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index b4aed04199129..9d9150246c8d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -66,7 +66,7 @@ private[sql] object CatalystConverter {
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
type ArrayScalaType[T] = Seq[T]
- type StructScalaType[T] = Seq[T]
+ type StructScalaType[T] = Row
type MapScalaType[K, V] = Map[K, V]
protected[parquet] def createConverter(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 381298caba6f2..171b816a26332 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -18,32 +18,32 @@
package org.apache.spark.sql.sources
import scala.language.implicitConversions
-import scala.util.parsing.combinator.syntactical.StandardTokenParsers
-import scala.util.parsing.combinator.PackratParsers
import org.apache.spark.Logging
import org.apache.spark.sql.{SchemaRDD, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.SqlLexical
+import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
+
/**
* A parser for foreign DDL commands.
*/
-private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging {
-
- def apply(input: String): Option[LogicalPlan] = {
- phrase(ddl)(new lexical.Scanner(input)) match {
- case Success(r, x) => Some(r)
- case x =>
- logDebug(s"Not recognized as DDL: $x")
- None
+private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
+
+ def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = {
+ try {
+ Some(apply(input))
+ } catch {
+ case _ if !exceptionOnError => None
+ case x: Throwable => throw x
}
}
def parseType(input: String): DataType = {
+ lexical.initialize(reservedWords)
phrase(dataType)(new lexical.Scanner(input)) match {
case Success(r, x) => r
case x =>
@@ -51,11 +51,9 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
}
}
- protected case class Keyword(str: String)
-
- protected implicit def asParser(k: Keyword): Parser[String] =
- lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+ // `Keyword` is a convention of AbstractSparkSQLParser: the base class scans all of the
+ // `Keyword` properties of this class via reflection at runtime to construct the SqlLexical object.
protected val CREATE = Keyword("CREATE")
protected val TEMPORARY = Keyword("TEMPORARY")
protected val TABLE = Keyword("TABLE")
@@ -80,17 +78,10 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected val MAP = Keyword("MAP")
protected val STRUCT = Keyword("STRUCT")
- // Use reflection to find the reserved words defined in this class.
- protected val reservedWords =
- this.getClass
- .getMethods
- .filter(_.getReturnType == classOf[Keyword])
- .map(_.invoke(this).asInstanceOf[Keyword].str)
-
- override val lexical = new SqlLexical(reservedWords)
-
protected lazy val ddl: Parser[LogicalPlan] = createTable
+ protected def start: Parser[LogicalPlan] = ddl
+
/**
* `CREATE [TEMPORARY] TABLE avroTable
* USING org.apache.spark.sql.avro
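Editor's note: with the lexer setup moved into AbstractSparkSQLParser, each subclass only declares `Keyword` vals and the base class discovers them reflectively, the same pattern the deleted block above implemented locally. A standalone sketch of that discovery (class and keywords are illustrative):

```scala
case class Keyword(str: String)

class MyDDLParser {
  val CREATE = Keyword("CREATE")
  val TEMPORARY = Keyword("TEMPORARY")
  val TABLE = Keyword("TABLE")

  // Every no-arg accessor returning Keyword is treated as a reserved word.
  lazy val reservedWords: Seq[String] =
    getClass.getMethods
      .filter(_.getReturnType == classOf[Keyword])
      .map(_.invoke(this).asInstanceOf[Keyword].str)
      .toSeq
}

new MyDDLParser().reservedWords  // CREATE, TEMPORARY, TABLE (order unspecified)
```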
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 2bcfe28456997..afbfe214f1ce4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -45,28 +45,28 @@ class DslQuerySuite extends QueryTest {
test("agg") {
checkAnswer(
testData2.groupBy('a)('a, sum('b)),
- Seq((1,3),(2,3),(3,3))
+ Seq(Row(1,3), Row(2,3), Row(3,3))
)
checkAnswer(
testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
- 9
+ Row(9)
)
checkAnswer(
testData2.aggregate(sum('b)),
- 9
+ Row(9)
)
}
test("convert $\"attribute name\" into unresolved attribute") {
checkAnswer(
testData.where($"key" === 1).select($"value"),
- Seq(Seq("1")))
+ Row("1"))
}
test("convert Scala Symbol 'attrname into unresolved attribute") {
checkAnswer(
testData.where('key === 1).select('value),
- Seq(Seq("1")))
+ Row("1"))
}
test("select *") {
@@ -78,61 +78,61 @@ class DslQuerySuite extends QueryTest {
test("simple select") {
checkAnswer(
testData.where('key === 1).select('value),
- Seq(Seq("1")))
+ Row("1"))
}
test("select with functions") {
checkAnswer(
testData.select(sum('value), avg('value), count(1)),
- Seq(Seq(5050.0, 50.5, 100)))
+ Row(5050.0, 50.5, 100))
checkAnswer(
testData2.select('a + 'b, 'a < 'b),
Seq(
- Seq(2, false),
- Seq(3, true),
- Seq(3, false),
- Seq(4, false),
- Seq(4, false),
- Seq(5, false)))
+ Row(2, false),
+ Row(3, true),
+ Row(3, false),
+ Row(4, false),
+ Row(4, false),
+ Row(5, false)))
checkAnswer(
testData2.select(sumDistinct('a)),
- Seq(Seq(6)))
+ Row(6))
}
test("global sorting") {
checkAnswer(
testData2.orderBy('a.asc, 'b.asc),
- Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+ Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
checkAnswer(
testData2.orderBy('a.asc, 'b.desc),
- Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
+ Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
checkAnswer(
testData2.orderBy('a.desc, 'b.desc),
- Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
+ Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1)))
checkAnswer(
testData2.orderBy('a.desc, 'b.asc),
- Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+ Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
checkAnswer(
arrayData.orderBy('data.getItem(0).asc),
- arrayData.collect().sortBy(_.data(0)).toSeq)
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(0).desc),
- arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
- mapData.orderBy('data.getItem(1).asc),
- mapData.collect().sortBy(_.data(1)).toSeq)
+ arrayData.orderBy('data.getItem(1).asc),
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
- mapData.orderBy('data.getItem(1).desc),
- mapData.collect().sortBy(_.data(1)).reverse.toSeq)
+ arrayData.orderBy('data.getItem(1).desc),
+ arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("partition wide sorting") {
@@ -147,19 +147,19 @@ class DslQuerySuite extends QueryTest {
// (3, 2)
checkAnswer(
testData2.sortBy('a.asc, 'b.asc),
- Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+ Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
checkAnswer(
testData2.sortBy('a.asc, 'b.desc),
- Seq((1,2), (1,1), (2,1), (2,2), (3,2), (3,1)))
+ Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1)))
checkAnswer(
testData2.sortBy('a.desc, 'b.desc),
- Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2)))
+ Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2)))
checkAnswer(
testData2.sortBy('a.desc, 'b.asc),
- Seq((2,1), (1,1), (1,2), (3,1), (3,2), (2,2)))
+ Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2)))
}
test("limit") {
@@ -169,11 +169,11 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
arrayData.limit(1),
- arrayData.take(1).toSeq)
+ arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
checkAnswer(
mapData.limit(1),
- mapData.take(1).toSeq)
+ mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}
test("SPARK-3395 limit distinct") {
@@ -184,8 +184,8 @@ class DslQuerySuite extends QueryTest {
.registerTempTable("onerow")
checkAnswer(
sql("select * from onerow inner join testData2 on onerow.a = testData2.a"),
- (1, 1, 1, 1) ::
- (1, 1, 1, 2) :: Nil)
+ Row(1, 1, 1, 1) ::
+ Row(1, 1, 1, 2) :: Nil)
}
test("SPARK-3858 generator qualifiers are discarded") {
@@ -193,55 +193,55 @@ class DslQuerySuite extends QueryTest {
arrayData.as('ad)
.generate(Explode("data" :: Nil, 'data), alias = Some("ex"))
.select("ex.data".attr),
- Seq(1, 2, 3, 2, 3, 4).map(Seq(_)))
+ Seq(1, 2, 3, 2, 3, 4).map(Row(_)))
}
test("average") {
checkAnswer(
testData2.aggregate(avg('a)),
- 2.0)
+ Row(2.0))
checkAnswer(
testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
- (2.0, 6.0) :: Nil)
+ Row(2.0, 6.0) :: Nil)
checkAnswer(
decimalData.aggregate(avg('a)),
- new java.math.BigDecimal(2.0))
+ Row(new java.math.BigDecimal(2.0)))
checkAnswer(
decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
- (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+ Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2))),
- new java.math.BigDecimal(2.0))
+ Row(new java.math.BigDecimal(2.0)))
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
- (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+ Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}
test("null average") {
checkAnswer(
testData3.aggregate(avg('b)),
- 2.0)
+ Row(2.0))
checkAnswer(
testData3.aggregate(avg('b), countDistinct('b)),
- (2.0, 1) :: Nil)
+ Row(2.0, 1))
checkAnswer(
testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
- (2.0, 2.0) :: Nil)
+ Row(2.0, 2.0))
}
test("zero average") {
checkAnswer(
emptyTableData.aggregate(avg('a)),
- null)
+ Row(null))
checkAnswer(
emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
- (null, null) :: Nil)
+ Row(null, null))
}
test("count") {
@@ -249,28 +249,28 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
testData2.aggregate(count('a), sumDistinct('a)), // non-partial
- (6, 6.0) :: Nil)
+ Row(6, 6.0))
}
test("null count") {
checkAnswer(
testData3.groupBy('a)('a, count('b)),
- Seq((1,0), (2, 1))
+ Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
testData3.groupBy('a)('a, count('a + 'b)),
- Seq((1,0), (2, 1))
+ Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
- (2, 1, 2, 2, 1) :: Nil
+ Row(2, 1, 2, 2, 1)
)
checkAnswer(
testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
- (1, 1, 2) :: Nil
+ Row(1, 1, 2)
)
}
@@ -279,28 +279,28 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
- (0, null) :: Nil)
+ Row(0, null))
}
test("zero sum") {
checkAnswer(
emptyTableData.aggregate(sum('a)),
- null)
+ Row(null))
}
test("zero sum distinct") {
checkAnswer(
emptyTableData.aggregate(sumDistinct('a)),
- null)
+ Row(null))
}
test("except") {
checkAnswer(
lowerCaseData.except(upperCaseData),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.except(lowerCaseData), Nil)
checkAnswer(upperCaseData.except(upperCaseData), Nil)
}
@@ -308,10 +308,10 @@ class DslQuerySuite extends QueryTest {
test("intersect") {
checkAnswer(
lowerCaseData.intersect(lowerCaseData),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
}
@@ -321,75 +321,75 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
// SELECT *, foo(key, value) FROM testData
testData.select(Star(None), foo.call('key, 'value)).limit(3),
- (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
+ Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
)
}
test("sqrt") {
checkAnswer(
testData.select(sqrt('key)).orderBy('key asc),
- (1 to 100).map(n => Seq(math.sqrt(n)))
+ (1 to 100).map(n => Row(math.sqrt(n)))
)
checkAnswer(
testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc),
- (1 to 100).map(n => Seq(math.sqrt(n), n))
+ (1 to 100).map(n => Row(math.sqrt(n), n))
)
checkAnswer(
testData.select(sqrt(Literal(null))),
- (1 to 100).map(_ => Seq(null))
+ (1 to 100).map(_ => Row(null))
)
}
test("abs") {
checkAnswer(
testData.select(abs('key)).orderBy('key asc),
- (1 to 100).map(n => Seq(n))
+ (1 to 100).map(n => Row(n))
)
checkAnswer(
negativeData.select(abs('key)).orderBy('key desc),
- (1 to 100).map(n => Seq(n))
+ (1 to 100).map(n => Row(n))
)
checkAnswer(
testData.select(abs(Literal(null))),
- (1 to 100).map(_ => Seq(null))
+ (1 to 100).map(_ => Row(null))
)
}
test("upper") {
checkAnswer(
lowerCaseData.select(upper('l)),
- ('a' to 'd').map(c => Seq(c.toString.toUpperCase()))
+ ('a' to 'd').map(c => Row(c.toString.toUpperCase()))
)
checkAnswer(
testData.select(upper('value), 'key),
- (1 to 100).map(n => Seq(n.toString, n))
+ (1 to 100).map(n => Row(n.toString, n))
)
checkAnswer(
testData.select(upper(Literal(null))),
- (1 to 100).map(n => Seq(null))
+ (1 to 100).map(n => Row(null))
)
}
test("lower") {
checkAnswer(
upperCaseData.select(lower('L)),
- ('A' to 'F').map(c => Seq(c.toString.toLowerCase()))
+ ('A' to 'F').map(c => Row(c.toString.toLowerCase()))
)
checkAnswer(
testData.select(lower('value), 'key),
- (1 to 100).map(n => Seq(n.toString, n))
+ (1 to 100).map(n => Row(n.toString, n))
)
checkAnswer(
testData.select(lower(Literal(null))),
- (1 to 100).map(n => Seq(null))
+ (1 to 100).map(n => Row(null))
)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index e5ab16f9dd661..cd36da7751e83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -117,10 +117,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")
))
}
@@ -128,10 +128,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")
))
}
@@ -140,10 +140,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
val y = testData2.where('a === 1).as('y)
checkAnswer(
x.join(y).where("x.a".attr === "y.a".attr),
- (1,1,1,1) ::
- (1,1,1,2) ::
- (1,2,1,1) ::
- (1,2,1,2) :: Nil
+ Row(1,1,1,1) ::
+ Row(1,1,1,2) ::
+ Row(1,2,1,1) ::
+ Row(1,2,1,2) :: Nil
)
}
@@ -163,54 +163,54 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
testData.flatMap(
- row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+ row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
test("cartisian product join") {
checkAnswer(
testData3.join(testData3),
- (1, null, 1, null) ::
- (1, null, 2, 2) ::
- (2, 2, 1, null) ::
- (2, 2, 2, 2) :: Nil)
+ Row(1, null, 1, null) ::
+ Row(1, null, 2, 2) ::
+ Row(2, 2, 1, null) ::
+ Row(2, 2, 2, 2) :: Nil)
}
test("left outer join") {
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
- (1, "A", 1, "a") ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", 1, "a") ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
- (1, "A", null, null) ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
- (1, "A", null, null) ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
- (1, "A", 1, "a") ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", 1, "a") ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
// Make sure we are choosing left.outputPartitioning as the
// outputPartitioning for the outer join operator.
@@ -221,12 +221,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY l.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) :: Nil)
checkAnswer(
sql(
@@ -235,42 +235,42 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY r.a
""".stripMargin),
- (null, 6) :: Nil)
+ Row(null, 6) :: Nil)
}
test("right outer join") {
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
- (1, "a", 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "a", 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)),
- (null, null, 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(null, null, 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)),
- (null, null, 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(null, null, 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)),
- (1, "a", 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "a", 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
// Make sure we are choosing right.outputPartitioning as the
// outputPartitioning for the outer join operator.
@@ -281,7 +281,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY l.a
""".stripMargin),
- (null, 6) :: Nil)
+ Row(null, 6))
checkAnswer(
sql(
@@ -290,12 +290,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY r.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) :: Nil)
}
test("full outer join") {
@@ -307,32 +307,32 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", null, null) ::
- (null, null, 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", null, null) ::
+ Row(null, null, 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", null, null) ::
- (null, null, 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", null, null) ::
+ Row(null, null, 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
// Make sure we are using UnknownPartitioning as the outputPartitioning for the outer join operator.
checkAnswer(
@@ -342,7 +342,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY l.a
""".stripMargin),
- (null, 10) :: Nil)
+ Row(null, 10))
checkAnswer(
sql(
@@ -351,13 +351,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
|GROUP BY r.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) ::
- (null, 4) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) ::
+ Row(null, 4) :: Nil)
checkAnswer(
sql(
@@ -366,13 +366,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY l.N
""".stripMargin),
- (1, 1) ::
- (2, 1) ::
- (3, 1) ::
- (4, 1) ::
- (5, 1) ::
- (6, 1) ::
- (null, 4) :: Nil)
+ Row(1, 1) ::
+ Row(2, 1) ::
+ Row(3, 1) ::
+ Row(4, 1) ::
+ Row(5, 1) ::
+ Row(6, 1) ::
+ Row(null, 4) :: Nil)
checkAnswer(
sql(
@@ -381,7 +381,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
|GROUP BY r.a
""".stripMargin),
- (null, 10) :: Nil)
+ Row(null, 10))
}
test("broadcasted left semi join operator selection") {
@@ -412,12 +412,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("left semi join") {
val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(rdd,
- (1, 1) ::
- (1, 2) ::
- (2, 1) ::
- (2, 2) ::
- (3, 1) ::
- (3, 2) :: Nil)
+ Row(1, 1) ::
+ Row(1, 2) ::
+ Row(2, 1) ::
+ Row(2, 2) ::
+ Row(3, 1) ::
+ Row(3, 2) :: Nil)
}
}
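
The Row.merge rewrites above replace the old (row ++ row).toSeq idiom, which stopped working once Row was no longer a Seq. A minimal sketch of its behavior:

import org.apache.spark.sql.Row

// Row.merge concatenates the fields of its arguments into one Row.
val row = Row(1, "a")
assert(Row.merge(row, row) == Row(1, "a", 1, "a"))
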
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 68ddecc7f610d..42a21c148df53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -47,26 +47,17 @@ class QueryTest extends PlanTest {
* @param rdd the [[SchemaRDD]] to be executed
- * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
+ * @param expectedAnswer the expected result, expressed as a Seq of Rows.
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = {
- val convertedAnswer = expectedAnswer match {
- case s: Seq[_] if s.isEmpty => s
- case s: Seq[_] if s.head.isInstanceOf[Product] &&
- !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq)
- case s: Seq[_] => s
- case singleItem => Seq(Seq(singleItem))
- }
-
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
- def prepareAnswer(answer: Seq[Any]): Seq[Any] = {
+ def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
// Java's java.math.BigDecimal.compareTo).
- val converted = answer.map {
- case s: Seq[_] => s.map {
+ val converted: Seq[Row] = answer.map { s =>
+ Row.fromSeq(s.toSeq.map {
case d: java.math.BigDecimal => BigDecimal(d)
case o => o
- }
- case o => o
+ })
}
if (!isSorted) converted.sortBy(_.toString) else converted
}
@@ -82,7 +73,7 @@ class QueryTest extends PlanTest {
""".stripMargin)
}
- if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+ if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
fail(s"""
|Results do not match for query:
|${rdd.logicalPlan}
@@ -92,15 +83,19 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${convertedAnswer.size} ==" +:
- prepareAnswer(convertedAnswer).map(_.toString),
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
s"== Spark Answer - ${sparkAnswer.size} ==" +:
prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
""".stripMargin)
}
}
- def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = {
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
test(sqlString) {
checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
}
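
With the Any parameter gone, call sites resolve to one of the two overloads above: a bare Row for single-row answers, or a Seq[Row] for multi-row ones. A usage sketch, assuming the standard six-row testData2 fixture (note COUNT returns a Long):

// Single-row answer: resolves to the Row overload.
checkAnswer(sql("SELECT COUNT(*) FROM testData2"), Row(6L))

// Multi-row answer: resolves to the Seq[Row] overload.
checkAnswer(
  sql("SELECT * FROM testData2 ORDER BY a, b"),
  Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
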
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 54fabc5c915fb..03b44ca1d6695 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -46,7 +46,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
sql("SELECT a FROM testData2 SORT BY a"),
- Seq(1, 1, 2 ,2 ,3 ,3).map(Seq(_))
+ Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_))
)
}
@@ -70,13 +70,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-3176 Added Parser of SQL ABS()") {
checkAnswer(
sql("SELECT ABS(-1.3)"),
- 1.3)
+ Row(1.3))
checkAnswer(
sql("SELECT ABS(0.0)"),
- 0.0)
+ Row(0.0))
checkAnswer(
sql("SELECT ABS(2.5)"),
- 2.5)
+ Row(2.5))
}
test("aggregation with codegen") {
@@ -89,13 +89,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-3176 Added Parser of SQL LAST()") {
checkAnswer(
sql("SELECT LAST(n) FROM lowerCaseData"),
- 4)
+ Row(4))
}
test("SPARK-2041 column name equals tablename") {
checkAnswer(
sql("SELECT tableName FROM tableName"),
- "test")
+ Row("test"))
}
test("SQRT") {
@@ -115,40 +115,40 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-2407 Added Parser of SQL SUBSTR()") {
checkAnswer(
sql("SELECT substr(tableName, 1, 2) FROM tableName"),
- "te")
+ Row("te"))
checkAnswer(
sql("SELECT substr(tableName, 3) FROM tableName"),
- "st")
+ Row("st"))
checkAnswer(
sql("SELECT substring(tableName, 1, 2) FROM tableName"),
- "te")
+ Row("te"))
checkAnswer(
sql("SELECT substring(tableName, 3) FROM tableName"),
- "st")
+ Row("st"))
}
test("SPARK-3173 Timestamp support in the parser") {
checkAnswer(sql(
"SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))))
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))
checkAnswer(sql(
"SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))))
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))
checkAnswer(sql(
"SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))))
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))
checkAnswer(sql(
"""SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003'
AND time>'1970-01-01 00:00:00.001'"""),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))))
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))
checkAnswer(sql(
"SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"),
- Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")),
- Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))))
+ Seq(Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")),
+ Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))))
checkAnswer(sql(
"SELECT time FROM timestamps WHERE time='123'"),
@@ -158,13 +158,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("index into array") {
checkAnswer(
sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"),
- arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
+ arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect())
}
test("left semi greater than predicate") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
- Seq((3,1), (3,2))
+ Seq(Row(3,1), Row(3,2))
)
}
@@ -173,7 +173,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql(
"SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"),
arrayData.map(d =>
- (d.nestedData,
+ Row(d.nestedData,
d.nestedData(0)(0),
d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq)
}
@@ -181,13 +181,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("agg") {
checkAnswer(
sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"),
- Seq((1,3),(2,3),(3,3)))
+ Seq(Row(1,3), Row(2,3), Row(3,3)))
}
test("aggregates with nulls") {
checkAnswer(
sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"),
- (1, 3, 2, 6, 3) :: Nil
+ Row(1, 3, 2, 6, 3)
)
}
@@ -200,29 +200,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("simple select") {
checkAnswer(
sql("SELECT value FROM testData WHERE key = 1"),
- Seq(Seq("1")))
+ Row("1"))
}
def sortTest() = {
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"),
- Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
+ Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"),
- Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
+ Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"),
- Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
+ Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1)))
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
- Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+ Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
checkAnswer(
sql("SELECT b FROM binaryData ORDER BY a ASC"),
- (1 to 5).map(Row(_)).toSeq)
+ (1 to 5).map(Row(_)))
checkAnswer(
sql("SELECT b FROM binaryData ORDER BY a DESC"),
@@ -230,19 +230,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
- arrayData.collect().sortBy(_.data(0)).toSeq)
+ arrayData.collect().sortBy(_.data(0)).map(Row.fromTuple).toSeq)
checkAnswer(
sql("SELECT * FROM arrayData ORDER BY data[0] DESC"),
- arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+ arrayData.collect().sortBy(_.data(0)).reverse.map(Row.fromTuple).toSeq)
checkAnswer(
sql("SELECT * FROM mapData ORDER BY data[1] ASC"),
- mapData.collect().sortBy(_.data(1)).toSeq)
+ mapData.collect().sortBy(_.data(1)).map(Row.fromTuple).toSeq)
checkAnswer(
sql("SELECT * FROM mapData ORDER BY data[1] DESC"),
- mapData.collect().sortBy(_.data(1)).reverse.toSeq)
+ mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq)
}
test("sorting") {
@@ -266,94 +266,94 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT * FROM arrayData LIMIT 1"),
- arrayData.collect().take(1).toSeq)
+ arrayData.collect().take(1).map(Row.fromTuple).toSeq)
checkAnswer(
sql("SELECT * FROM mapData LIMIT 1"),
- mapData.collect().take(1).toSeq)
+ mapData.collect().take(1).map(Row.fromTuple).toSeq)
}
test("from follow multiple brackets") {
checkAnswer(sql(
"select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"),
- 1
+ Row(1)
)
checkAnswer(sql(
"select key from (select * from testData) x limit 1"),
- 1
+ Row(1)
)
checkAnswer(sql(
"select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"),
- 1
+ Row(1)
)
}
test("average") {
checkAnswer(
sql("SELECT AVG(a) FROM testData2"),
- 2.0)
+ Row(2.0))
}
test("average overflow") {
checkAnswer(
sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"),
- Seq((2147483645.0,1),(2.0,2)))
+ Seq(Row(2147483645.0,1), Row(2.0,2)))
}
test("count") {
checkAnswer(
sql("SELECT COUNT(*) FROM testData2"),
- testData2.count())
+ Row(testData2.count()))
}
test("count distinct") {
checkAnswer(
sql("SELECT COUNT(DISTINCT b) FROM testData2"),
- 2)
+ Row(2))
}
test("approximate count distinct") {
checkAnswer(
sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
- 3)
+ Row(3))
}
test("approximate count distinct with user provided standard deviation") {
checkAnswer(
sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
- 3)
+ Row(3))
}
test("null count") {
checkAnswer(
sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"),
- Seq((1, 0), (2, 1)))
+ Seq(Row(1, 0), Row(2, 1)))
checkAnswer(
sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"),
- (2, 1, 2, 2, 1) :: Nil)
+ Row(2, 1, 2, 2, 1))
}
test("inner join where, one match per row") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")))
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")))
}
test("inner join ON, one match per row") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")))
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")))
}
test("inner join, where, multiple matches") {
@@ -363,10 +363,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| (SELECT * FROM testData2 WHERE a = 1) x JOIN
| (SELECT * FROM testData2 WHERE a = 1) y
|WHERE x.a = y.a""".stripMargin),
- (1,1,1,1) ::
- (1,1,1,2) ::
- (1,2,1,1) ::
- (1,2,1,2) :: Nil)
+ Row(1,1,1,1) ::
+ Row(1,1,1,2) ::
+ Row(1,2,1,1) ::
+ Row(1,2,1,2) :: Nil)
}
test("inner join, no matches") {
@@ -397,38 +397,38 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| SELECT * FROM testData) y
|WHERE x.key = y.key""".stripMargin),
testData.flatMap(
- row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+ row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
ignore("cartesian product join") {
checkAnswer(
testData3.join(testData3),
- (1, null, 1, null) ::
- (1, null, 2, 2) ::
- (2, 2, 1, null) ::
- (2, 2, 2, 2) :: Nil)
+ Row(1, null, 1, null) ::
+ Row(1, null, 2, 2) ::
+ Row(2, 2, 1, null) ::
+ Row(2, 2, 2, 2) :: Nil)
}
test("left outer join") {
checkAnswer(
sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"),
- (1, "A", 1, "a") ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
+ Row(1, "A", 1, "a") ::
+ Row(2, "B", 2, "b") ::
+ Row(3, "C", 3, "c") ::
+ Row(4, "D", 4, "d") ::
+ Row(5, "E", null, null) ::
+ Row(6, "F", null, null) :: Nil)
}
test("right outer join") {
checkAnswer(
sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"),
- (1, "a", 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "a", 1, "A") ::
+ Row(2, "b", 2, "B") ::
+ Row(3, "c", 3, "C") ::
+ Row(4, "d", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
}
test("full outer join") {
@@ -440,12 +440,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| (SELECT * FROM upperCaseData WHERE N >= 3) rightTable
| ON leftTable.N = rightTable.N
""".stripMargin),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
}
test("SPARK-3349 partitioning after limit") {
@@ -457,12 +457,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.registerTempTable("subset2")
checkAnswer(
sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"),
- (3, "c", 3) ::
- (4, "d", 4) :: Nil)
+ Row(3, "c", 3) ::
+ Row(4, "d", 4) :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"),
- (1, "a", 1) ::
- (2, "b", 2) :: Nil)
+ Row(1, "a", 1) ::
+ Row(2, "b", 2) :: Nil)
}
test("mixed-case keywords") {
@@ -474,28 +474,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| (sElEcT * FROM upperCaseData whERe N >= 3) rightTable
| oN leftTable.N = rightTable.N
""".stripMargin),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
+ Row(1, "A", null, null) ::
+ Row(2, "B", null, null) ::
+ Row(3, "C", 3, "C") ::
+ Row(4, "D", 4, "D") ::
+ Row(null, null, 5, "E") ::
+ Row(null, null, 6, "F") :: Nil)
}
test("select with table name as qualifier") {
checkAnswer(
sql("SELECT testData.value FROM testData WHERE testData.key = 1"),
- Seq(Seq("1")))
+ Row("1"))
}
test("inner join ON with table name as qualifier") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"),
Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")))
+ Row(1, "A", 1, "a"),
+ Row(2, "B", 2, "b"),
+ Row(3, "C", 3, "c"),
+ Row(4, "D", 4, "d")))
}
test("qualified select with inner join ON with table name as qualifier") {
@@ -503,72 +503,72 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " +
"ON lowerCaseData.n = upperCaseData.N"),
Seq(
- (1, "A"),
- (2, "B"),
- (3, "C"),
- (4, "D")))
+ Row(1, "A"),
+ Row(2, "B"),
+ Row(3, "C"),
+ Row(4, "D")))
}
test("system function upper()") {
checkAnswer(
sql("SELECT n,UPPER(l) FROM lowerCaseData"),
Seq(
- (1, "A"),
- (2, "B"),
- (3, "C"),
- (4, "D")))
+ Row(1, "A"),
+ Row(2, "B"),
+ Row(3, "C"),
+ Row(4, "D")))
checkAnswer(
sql("SELECT n, UPPER(s) FROM nullStrings"),
Seq(
- (1, "ABC"),
- (2, "ABC"),
- (3, null)))
+ Row(1, "ABC"),
+ Row(2, "ABC"),
+ Row(3, null)))
}
test("system function lower()") {
checkAnswer(
sql("SELECT N,LOWER(L) FROM upperCaseData"),
Seq(
- (1, "a"),
- (2, "b"),
- (3, "c"),
- (4, "d"),
- (5, "e"),
- (6, "f")))
+ Row(1, "a"),
+ Row(2, "b"),
+ Row(3, "c"),
+ Row(4, "d"),
+ Row(5, "e"),
+ Row(6, "f")))
checkAnswer(
sql("SELECT n, LOWER(s) FROM nullStrings"),
Seq(
- (1, "abc"),
- (2, "abc"),
- (3, null)))
+ Row(1, "abc"),
+ Row(2, "abc"),
+ Row(3, null)))
}
test("UNION") {
checkAnswer(
sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"),
- (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
- (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+ Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") ::
+ Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"),
- (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil)
+ Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"),
- (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") ::
- (4, "d") :: (4, "d") :: Nil)
+ Row(1, "a") :: Row(1, "a") :: Row(2, "b") :: Row(2, "b") :: Row(3, "c") :: Row(3, "c") ::
+ Row(4, "d") :: Row(4, "d") :: Nil)
}
test("UNION with column mismatches") {
// Column name mismatches are allowed.
checkAnswer(
sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"),
- (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
- (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+ Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") ::
+ Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil)
// Column type mismatches are not allowed, forcing a type coercion.
checkAnswer(
sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"),
- ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_)))
+ ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_)))
// Column type mismatches where a coercion is not possible, in this case between integer
// and array types, trigger a TreeNodeException.
intercept[TreeNodeException[_]] {
@@ -579,10 +579,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("EXCEPT") {
checkAnswer(
sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil)
checkAnswer(
@@ -592,10 +592,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("INTERSECT") {
checkAnswer(
sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"),
- (1, "a") ::
- (2, "b") ::
- (3, "c") ::
- (4, "d") :: Nil)
+ Row(1, "a") ::
+ Row(2, "b") ::
+ Row(3, "c") ::
+ Row(4, "d") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil)
}
@@ -613,25 +613,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql(s"SET $testKey=$testVal")
checkAnswer(
sql("SET"),
- Seq(Seq(s"$testKey=$testVal"))
+ Row(s"$testKey=$testVal")
)
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
checkAnswer(
sql("set"),
Seq(
- Seq(s"$testKey=$testVal"),
- Seq(s"${testKey + testKey}=${testVal + testVal}"))
+ Row(s"$testKey=$testVal"),
+ Row(s"${testKey + testKey}=${testVal + testVal}"))
)
// "set key"
checkAnswer(
sql(s"SET $testKey"),
- Seq(Seq(s"$testKey=$testVal"))
+ Row(s"$testKey=$testVal")
)
checkAnswer(
sql(s"SET $nonexistentKey"),
- Seq(Seq(s"$nonexistentKey="))
+ Row(s"$nonexistentKey=")
)
conf.clear()
}
@@ -655,17 +655,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
schemaRDD1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
- (1, "A1", true, null) ::
- (2, "B2", false, null) ::
- (3, "C3", true, null) ::
- (4, "D4", true, 2147483644) :: Nil)
+ Row(1, "A1", true, null) ::
+ Row(2, "B2", false, null) ::
+ Row(3, "C3", true, null) ::
+ Row(4, "D4", true, 2147483644) :: Nil)
checkAnswer(
sql("SELECT f1, f4 FROM applySchema1"),
- (1, null) ::
- (2, null) ::
- (3, null) ::
- (4, 2147483644) :: Nil)
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, 2147483644) :: Nil)
val schema2 = StructType(
StructField("f1", StructType(
@@ -685,17 +685,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
schemaRDD2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
- (Seq(1, true), Map("A1" -> null)) ::
- (Seq(2, false), Map("B2" -> null)) ::
- (Seq(3, true), Map("C3" -> null)) ::
- (Seq(4, true), Map("D4" -> 2147483644)) :: Nil)
+ Row(Row(1, true), Map("A1" -> null)) ::
+ Row(Row(2, false), Map("B2" -> null)) ::
+ Row(Row(3, true), Map("C3" -> null)) ::
+ Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil)
checkAnswer(
sql("SELECT f1.f11, f2['D4'] FROM applySchema2"),
- (1, null) ::
- (2, null) ::
- (3, null) ::
- (4, 2147483644) :: Nil)
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, 2147483644) :: Nil)
// The value of a MapType column can be a mutable map.
val rowRDD3 = unparsedStrings.map { r =>
@@ -711,26 +711,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT f1.f11, f2['D4'] FROM applySchema3"),
- (1, null) ::
- (2, null) ::
- (3, null) ::
- (4, 2147483644) :: Nil)
+ Row(1, null) ::
+ Row(2, null) ::
+ Row(3, null) ::
+ Row(4, 2147483644) :: Nil)
}
test("SPARK-3423 BETWEEN") {
checkAnswer(
sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"),
- Seq((5, "5"), (6, "6"), (7, "7"))
+ Seq(Row(5, "5"), Row(6, "6"), Row(7, "7"))
)
checkAnswer(
sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"),
- Seq((7, "7"))
+ Row(7, "7")
)
checkAnswer(
sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"),
- Seq()
+ Nil
)
}
@@ -738,7 +738,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// TODO Ensure true/false string letter casing is consistent with Hive in all cases.
checkAnswer(
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
- ("true", "false") :: Nil)
+ Row("true", "false"))
}
test("metadata is propagated correctly") {
@@ -768,17 +768,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-3371 Renaming a function expression with group by gives error") {
udf.register("len", (s: String) => s.length)
checkAnswer(
- sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)
+ sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"),
+ Row(1))
}
test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") {
checkAnswer(
- sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1)
+ sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"),
+ Row(1))
}
test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") {
checkAnswer(
- sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
+ sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"),
+ Row(1))
}
test("throw errors for non-aggregate attributes with aggregation") {
@@ -808,130 +811,131 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Test to check we can use Long.MinValue") {
checkAnswer(
- sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Long.MinValue
+ sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue)
)
checkAnswer(
- sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), (1 to 100).map(Row(_)).toSeq
+ sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"),
+ (1 to 100).map(Row(_)).toSeq
)
}
test("Floating point number format") {
checkAnswer(
- sql("SELECT 0.3"), 0.3
+ sql("SELECT 0.3"), Row(0.3)
)
checkAnswer(
- sql("SELECT -0.8"), -0.8
+ sql("SELECT -0.8"), Row(-0.8)
)
checkAnswer(
- sql("SELECT .5"), 0.5
+ sql("SELECT .5"), Row(0.5)
)
checkAnswer(
- sql("SELECT -.18"), -0.18
+ sql("SELECT -.18"), Row(-0.18)
)
}
test("Auto cast integer type") {
checkAnswer(
- sql(s"SELECT ${Int.MaxValue + 1L}"), Int.MaxValue + 1L
+ sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L)
)
checkAnswer(
- sql(s"SELECT ${Int.MinValue - 1L}"), Int.MinValue - 1L
+ sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L)
)
checkAnswer(
- sql("SELECT 9223372036854775808"), new java.math.BigDecimal("9223372036854775808")
+ sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808"))
)
checkAnswer(
- sql("SELECT -9223372036854775809"), new java.math.BigDecimal("-9223372036854775809")
+ sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809"))
)
}
test("Test to check we can apply sign to expression") {
checkAnswer(
- sql("SELECT -100"), -100
+ sql("SELECT -100"), Row(-100)
)
checkAnswer(
- sql("SELECT +230"), 230
+ sql("SELECT +230"), Row(230)
)
checkAnswer(
- sql("SELECT -5.2"), -5.2
+ sql("SELECT -5.2"), Row(-5.2)
)
checkAnswer(
- sql("SELECT +6.8"), 6.8
+ sql("SELECT +6.8"), Row(6.8)
)
checkAnswer(
- sql("SELECT -key FROM testData WHERE key = 2"), -2
+ sql("SELECT -key FROM testData WHERE key = 2"), Row(-2)
)
checkAnswer(
- sql("SELECT +key FROM testData WHERE key = 3"), 3
+ sql("SELECT +key FROM testData WHERE key = 3"), Row(3)
)
checkAnswer(
- sql("SELECT -(key + 1) FROM testData WHERE key = 1"), -2
+ sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2)
)
checkAnswer(
- sql("SELECT - key + 1 FROM testData WHERE key = 10"), -9
+ sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9)
)
checkAnswer(
- sql("SELECT +(key + 5) FROM testData WHERE key = 5"), 10
+ sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10)
)
checkAnswer(
- sql("SELECT -MAX(key) FROM testData"), -100
+ sql("SELECT -MAX(key) FROM testData"), Row(-100)
)
checkAnswer(
- sql("SELECT +MAX(key) FROM testData"), 100
+ sql("SELECT +MAX(key) FROM testData"), Row(100)
)
checkAnswer(
- sql("SELECT - (-10)"), 10
+ sql("SELECT - (-10)"), Row(10)
)
checkAnswer(
- sql("SELECT + (-key) FROM testData WHERE key = 32"), -32
+ sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32)
)
checkAnswer(
- sql("SELECT - (+Max(key)) FROM testData"), -100
+ sql("SELECT - (+Max(key)) FROM testData"), Row(-100)
)
checkAnswer(
- sql("SELECT - - 3"), 3
+ sql("SELECT - - 3"), Row(3)
)
checkAnswer(
- sql("SELECT - + 20"), -20
+ sql("SELECT - + 20"), Row(-20)
)
checkAnswer(
- sql("SELEcT - + 45"), -45
+ sql("SELEcT - + 45"), Row(-45)
)
checkAnswer(
- sql("SELECT + + 100"), 100
+ sql("SELECT + + 100"), Row(100)
)
checkAnswer(
- sql("SELECT - - Max(key) FROM testData"), 100
+ sql("SELECT - - Max(key) FROM testData"), Row(100)
)
checkAnswer(
- sql("SELECT + - key FROM testData WHERE key = 33"), -33
+ sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33)
)
}
@@ -943,7 +947,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
|JOIN testData b ON a.key = b.key
|JOIN testData c ON a.key = c.key
""".stripMargin),
- (1 to 100).map(i => Seq(i, i, i)))
+ (1 to 100).map(i => Row(i, i, i)))
}
test("SPARK-3483 Special chars in column names") {
@@ -953,19 +957,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("SPARK-3814 Support Bitwise & operator") {
- checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), 1)
+ checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), Row(1))
}
test("SPARK-3814 Support Bitwise | operator") {
- checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), 1)
+ checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), Row(1))
}
test("SPARK-3814 Support Bitwise ^ operator") {
- checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), 1)
+ checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), Row(1))
}
test("SPARK-3814 Support Bitwise ~ operator") {
- checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), -2)
+ checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), Row(-2))
}
test("SPARK-4120 Join of multiple tables does not work in SparkSQL") {
@@ -975,40 +979,40 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
|FROM testData a,testData b,testData c
|where a.key = b.key and a.key = c.key
""".stripMargin),
- (1 to 100).map(i => Seq(i, i, i)))
+ (1 to 100).map(i => Row(i, i, i)))
}
test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") {
checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"),
- (11 to 100).map(i => Seq(i)))
+ (11 to 100).map(i => Row(i)))
}
test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") {
checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"),
- (1 to 99).map(i => Seq(i)))
+ (1 to 99).map(i => Row(i)))
}
test("SPARK-4322 Grouping field with struct field as sub expression") {
jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data")
- checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1)
+ checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1))
dropTempTable("data")
jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
- checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2)
+ checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2))
dropTempTable("data")
}
test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") {
checkAnswer(
sql("SELECT a + b FROM testData2 ORDER BY a"),
- Seq(2, 3, 3 ,4 ,4 ,5).map(Seq(_))
+ Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_))
)
}
test("oder by asc by default when not specify ascending and descending") {
checkAnswer(
sql("SELECT a, b FROM testData2 ORDER BY a desc, b"),
- Seq((3, 1), (3, 2), (2, 1), (2,2), (1, 1), (1, 2))
+ Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2))
)
}
@@ -1021,13 +1025,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
rdd2.registerTempTable("nulldata2")
checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " +
"nulldata2 on nulldata1.value <=> nulldata2.value"),
- (1 to 2).map(i => Seq(i)))
+ (1 to 2).map(i => Row(i)))
}
test("Multi-column COUNT(DISTINCT ...)") {
val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil
val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
rdd.registerTempTable("distinctData")
- checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), 2)
+ checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))
}
}
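
Several expectations above compare against collected case-class fixtures; since case classes are Products, Row.fromTuple is the natural bridge. A small sketch (the case class is illustrative):

import org.apache.spark.sql.Row

// Case classes are Products, so fixture data converts field-by-field.
case class KV(key: Int, value: String)
val expected: Seq[Row] = Seq(KV(1, "1"), KV(2, "2")).map(Row.fromTuple)
assert(expected == Seq(Row(1, "1"), Row(2, "2")))
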
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index ee381da491054..a015884bae282 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -82,7 +82,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
rdd.registerTempTable("reflectData")
assert(sql("SELECT * FROM reflectData").collect().head ===
- Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
+ Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)))
}
@@ -91,7 +91,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerTempTable("reflectNullData")
- assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null))
+ assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
test("query case class RDD with Nones") {
@@ -99,7 +99,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerTempTable("reflectOptionalData")
- assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null))
+ assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
// Equality is broken for Arrays, so we test that separately.
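
The all-null expectations above need Row.fromSeq because Seq.fill(7)(null) no longer compares equal to a Row. A minimal sketch:

import org.apache.spark.sql.Row

// An all-null row of a given width, now built explicitly.
val nullRow = Row.fromSeq(Seq.fill(7)(null))
assert(nullRow.length == 7)   // width is preserved
assert(nullRow.isNullAt(0))   // every field is null
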
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 9be0b38e689ff..be2b34de077c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -42,8 +42,8 @@ class ColumnStatsSuite extends FunSuite {
test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
- columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) =>
- assert(actual === expected)
+ columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach {
+ case (actual, expected) => assert(actual === expected)
}
}
@@ -54,7 +54,7 @@ class ColumnStatsSuite extends FunSuite {
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))
- val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType])
+ val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType])
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
val stats = columnStats.collectedStatistics
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index d94729ba92360..e61f3c39631da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -49,7 +49,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
checkAnswer(scan, testData.collect().map {
case Row(key: Int, value: String) => value -> key
- }.toSeq)
+ }.map(Row.fromTuple))
}
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
@@ -63,49 +63,49 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("SPARK-1678 regression: compression must not lose repeated values") {
checkAnswer(
sql("SELECT * FROM repeatedData"),
- repeatedData.collect().toSeq)
+ repeatedData.collect().toSeq.map(Row.fromTuple))
cacheTable("repeatedData")
checkAnswer(
sql("SELECT * FROM repeatedData"),
- repeatedData.collect().toSeq)
+ repeatedData.collect().toSeq.map(Row.fromTuple))
}
test("with null values") {
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
- nullableRepeatedData.collect().toSeq)
+ nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
cacheTable("nullableRepeatedData")
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
- nullableRepeatedData.collect().toSeq)
+ nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
}
test("SPARK-2729 regression: timestamp data type") {
checkAnswer(
sql("SELECT time FROM timestamps"),
- timestamps.collect().toSeq)
+ timestamps.collect().toSeq.map(Row.fromTuple))
cacheTable("timestamps")
checkAnswer(
sql("SELECT time FROM timestamps"),
- timestamps.collect().toSeq)
+ timestamps.collect().toSeq.map(Row.fromTuple))
}
test("SPARK-3320 regression: batched column buffer building should work with empty partitions") {
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
- withEmptyParts.collect().toSeq)
+ withEmptyParts.collect().toSeq.map(Row.fromTuple))
cacheTable("withEmptyParts")
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
- withEmptyParts.collect().toSeq)
+ withEmptyParts.collect().toSeq.map(Row.fromTuple))
}
test("SPARK-4182 Caching complex types") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 592cafbbdc203..c3a3f8ddc3ebf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -108,7 +108,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
val queryExecution = schemaRdd.queryExecution
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
- schemaRdd.collect().map(_.head).toArray
+ schemaRdd.collect().map(_(0)).toArray
}
val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index d9e488e0ffd16..8b518f094174c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -34,7 +34,7 @@ class BooleanBitSetSuite extends FunSuite {
val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN))
- val values = rows.map(_.head)
+ val values = rows.map(_(0))
rows.foreach(builder.appendFrom(_, 0))
val buffer = builder.build()
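
The _.head to _(0) rewrites here and in PartitionBatchPruningSuite follow from Row no longer extending Seq[Any]: the Seq combinators are gone and fields are read positionally. Sketch:

import org.apache.spark.sql.Row

val row = Row(true, "x")
val first = row(0)   // was row.head while Row still extended Seq[Any]
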
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c5b6fce5fd297..67007b8c093ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.planner._
+import org.apache.spark.sql.types._
class PlannerSuite extends FunSuite {
test("unions are collapsed") {
@@ -60,19 +61,62 @@ class PlannerSuite extends FunSuite {
}
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
- val origThreshold = conf.autoBroadcastJoinThreshold
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString)
-
- // Using a threshold that is definitely larger than the small testing table (b) below
- val a = testData.as('a)
- val b = testData.limit(3).as('b)
- val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
+ def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
+ setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString)
+ val fields = fieldTypes.zipWithIndex.map {
+ case (dataType, index) => StructField(s"c${index}", dataType, true)
+ } :+ StructField("key", IntegerType, true)
+ val schema = StructType(fields)
+ val row = Row.fromSeq(Seq.fill(fields.size)(null))
+ val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
+ applySchema(rowRDD, schema).registerTempTable("testLimit")
+
+ val planned = sql(
+ """
+ |SELECT l.a, l.b
+ |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key)
+ """.stripMargin).queryExecution.executedPlan
+
+ val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
+ val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+
+ assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
+ assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+
+ dropTempTable("testLimit")
+ }
- val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
- val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+ val origThreshold = conf.autoBroadcastJoinThreshold
- assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
- assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+ val simpleTypes =
+ NullType ::
+ BooleanType ::
+ ByteType ::
+ ShortType ::
+ IntegerType ::
+ LongType ::
+ FloatType ::
+ DoubleType ::
+ DecimalType(10, 5) ::
+ DecimalType.Unlimited ::
+ DateType ::
+ TimestampType ::
+ StringType ::
+ BinaryType :: Nil
+
+ checkPlan(simpleTypes, newThreshold = 16434)
+
+ val complexTypes =
+ ArrayType(DoubleType, true) ::
+ ArrayType(StringType, false) ::
+ MapType(IntegerType, StringType, true) ::
+ MapType(IntegerType, ArrayType(DoubleType), false) ::
+ StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", ArrayType(DoubleType), nullable = false),
+ StructField("c", DoubleType, nullable = false))) :: Nil
+
+ checkPlan(complexTypes, newThreshold = 901617)
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
}
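
For context, the checkPlan helper above builds its test table via the applySchema pattern: an RDD[Row] plus an explicit StructType. A stripped-down sketch under the same imports as the suite (the table name and fields are illustrative, with TestSQLContext assumed in scope):

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types._

// One nullable string column plus an integer key, and a single row.
val schema = StructType(
  StructField("c0", StringType, nullable = true) ::
  StructField("key", IntegerType, nullable = true) :: Nil)
val rowRDD = sparkContext.parallelize(Row(null, 1) :: Nil)
applySchema(rowRDD, schema).registerTempTable("exampleLimit")
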
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
index 2cab5e0c44d92..272c0d4cb2335 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
@@ -59,7 +59,7 @@ class TgfSuite extends QueryTest {
checkAnswer(
inputData.generate(ExampleTGF()),
Seq(
- "michael is 29 years old" :: Nil,
- "Next year, michael will be 30 years old" :: Nil))
+ Row("michael is 29 years old"),
+ Row("Next year, michael will be 30 years old")))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 2bc9aede32f2a..94d14acccbb18 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -229,13 +229,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- (new java.math.BigDecimal("92233720368547758070"),
- true,
- 1.7976931348623157E308,
- 10,
- 21474836470L,
- null,
- "this is a simple string.") :: Nil
+ Row(new java.math.BigDecimal("92233720368547758070"),
+ true,
+ 1.7976931348623157E308,
+ 10,
+ 21474836470L,
+ null,
+ "this is a simple string.")
)
}
@@ -271,48 +271,49 @@ class JsonSuite extends QueryTest {
// Access elements of a primitive array.
checkAnswer(
sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"),
- ("str1", "str2", null) :: Nil
+ Row("str1", "str2", null)
)
// Access an array of null values.
checkAnswer(
sql("select arrayOfNull from jsonTable"),
- Seq(Seq(null, null, null, null)) :: Nil
+ Row(Seq(null, null, null, null))
)
// Access elements of a BigInteger array (we use DecimalType internally).
checkAnswer(
sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"),
- (new java.math.BigDecimal("922337203685477580700"),
- new java.math.BigDecimal("-922337203685477580800"), null) :: Nil
+ Row(new java.math.BigDecimal("922337203685477580700"),
+ new java.math.BigDecimal("-922337203685477580800"), null)
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"),
- (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil
+ Row(Seq("1", "2", "3"), Seq("str1", "str2"))
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"),
- (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil
+ Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1))
)
// Access elements of an array inside a field of type ArrayType(ArrayType).
checkAnswer(
sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"),
- ("str2", 2.1) :: Nil
+ Row("str2", 2.1)
)
// Access elements of an array of structs.
checkAnswer(
sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " +
"from jsonTable"),
- (true :: "str1" :: null :: Nil,
- false :: null :: null :: Nil,
- null :: null :: null :: Nil,
- null) :: Nil
+ Row(
+ Row(true, "str1", null),
+ Row(false, null, null),
+ Row(null, null, null),
+ null)
)
// Access a struct and fields inside of it.
@@ -327,13 +328,13 @@ class JsonSuite extends QueryTest {
// Access an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"),
- (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil
+ Row(Seq(4, 5, 6), Seq("str1", "str2"))
)
// Access elements of an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"),
- (5, null) :: Nil
+ Row(5, null)
)
}
@@ -344,14 +345,14 @@ class JsonSuite extends QueryTest {
// Right now, "field1" and "field2" are treated as aliases. We should fix it.
checkAnswer(
sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
- (true, "str1") :: Nil
+ Row(true, "str1")
)
// Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2.
// Getting all values of a specific field from an array of structs.
checkAnswer(
sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"),
- (Seq(true, false), Seq("str1", null)) :: Nil
+ Row(Seq(true, false), Seq("str1", null))
)
}
@@ -372,57 +373,57 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- ("true", 11L, null, 1.1, "13.1", "str1") ::
- ("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") ::
- ("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") ::
- (null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil
+ Row("true", 11L, null, 1.1, "13.1", "str1") ::
+ Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") ::
+ Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") ::
+ Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil
)
// Number and Boolean conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_bool - 10 from jsonTable where num_bool > 11"),
- 2
+ Row(2)
)
// Widening to LongType
checkAnswer(
sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"),
- Seq(21474836370L) :: Seq(21474836470L) :: Nil
+ Row(21474836370L) :: Row(21474836470L) :: Nil
)
checkAnswer(
sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"),
- Seq(-89) :: Seq(21474836370L) :: Seq(21474836470L) :: Nil
+ Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
)
// Widening to DecimalType
checkAnswer(
sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"),
- Seq(new java.math.BigDecimal("21474836472.1")) :: Seq(new java.math.BigDecimal("92233720368547758071.2")) :: Nil
+ Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil
)
// Widening to DoubleType
checkAnswer(
sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"),
- Seq(101.2) :: Seq(21474836471.2) :: Nil
+ Row(101.2) :: Row(21474836471.2) :: Nil
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 14"),
- 92233720368547758071.2
+ Row(92233720368547758071.2)
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"),
- new java.math.BigDecimal("92233720368547758061.2").doubleValue
+ Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue)
)
// String and Boolean conflict: resolve the type as string.
checkAnswer(
sql("select * from jsonTable where str_bool = 'str1'"),
- ("true", 11L, null, 1.1, "13.1", "str1") :: Nil
+ Row("true", 11L, null, 1.1, "13.1", "str1")
)
}
@@ -434,24 +435,24 @@ class JsonSuite extends QueryTest {
// Number and Boolean conflict: resolve the type as boolean in this query.
checkAnswer(
sql("select num_bool from jsonTable where NOT num_bool"),
- false
+ Row(false)
)
checkAnswer(
sql("select str_bool from jsonTable where NOT str_bool"),
- false
+ Row(false)
)
// Right now, the analyzer does not know that num_bool should be treated as a boolean.
// Number and Boolean conflict: resolve the type as boolean in this query.
checkAnswer(
sql("select num_bool from jsonTable where num_bool"),
- true
+ Row(true)
)
checkAnswer(
sql("select str_bool from jsonTable where str_bool"),
- false
+ Row(false)
)
// The plan of the following DSL is
@@ -464,7 +465,7 @@ class JsonSuite extends QueryTest {
jsonSchemaRDD.
where('num_str > BigDecimal("92233720368547758060")).
select('num_str + 1.2 as Symbol("num")),
- new java.math.BigDecimal("92233720368547758061.2")
+ Row(new java.math.BigDecimal("92233720368547758061.2"))
)
// The following test will fail. The type of num_str is StringType.
@@ -475,7 +476,7 @@ class JsonSuite extends QueryTest {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 13"),
- Seq(14.3) :: Seq(92233720368547758071.2) :: Nil
+ Row(14.3) :: Row(92233720368547758071.2) :: Nil
)
}
@@ -496,10 +497,10 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- (Seq(), "11", "[1,2,3]", Seq(null), "[]") ::
- (null, """{"field":false}""", null, null, "{}") ::
- (Seq(4, 5, 6), null, "str", Seq(null), "[7,8,9]") ::
- (Seq(7), "{}","[str1,str2,33]", Seq("str"), """{"field":true}""") :: Nil
+ Row(Seq(), "11", "[1,2,3]", Row(null), "[]") ::
+ Row(null, """{"field":false}""", null, null, "{}") ::
+ Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") ::
+ Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil
)
}
@@ -518,16 +519,16 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]",
- """{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) ::
- Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) ::
- Seq(null, null, Seq("1", "2", "3")) :: Nil
+ Row(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]",
+ """{"field":"str"}"""), Seq(Row(214748364700L), Row(1)), null) ::
+ Row(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) ::
+ Row(null, null, Seq("1", "2", "3")) :: Nil
)
// Treat an element as a number.
checkAnswer(
sql("select array1[0] + 1 from jsonTable where array1 is not null"),
- 2
+ Row(2)
)
}
@@ -568,13 +569,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
null,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
}
@@ -594,13 +595,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTableSQL"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
null,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
}
@@ -626,13 +627,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable1"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
null,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema)
@@ -643,13 +644,13 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable2"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
null,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
}
@@ -659,7 +660,7 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
- (true, "str1") :: Nil
+ Row(true, "str1")
)
checkAnswer(
sql(
@@ -667,7 +668,7 @@ class JsonSuite extends QueryTest {
|select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1]
|from jsonTable
""".stripMargin),
- ("str2", 6) :: Nil
+ Row("str2", 6)
)
}
@@ -681,7 +682,7 @@ class JsonSuite extends QueryTest {
|select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0]
|from jsonTable
""".stripMargin),
- (5, 7, 8) :: Nil
+ Row(5, 7, 8)
)
checkAnswer(
sql(
@@ -690,7 +691,7 @@ class JsonSuite extends QueryTest {
|arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4
|from jsonTable
""".stripMargin),
- ("str1", Nil, "str4", 2) :: Nil
+ Row("str1", Nil, "str4", 2)
)
}
@@ -704,10 +705,10 @@ class JsonSuite extends QueryTest {
|select a, b, c
|from jsonTable
""".stripMargin),
- ("str_a_1", null, null) ::
- ("str_a_2", null, null) ::
- (null, "str_b_3", null) ::
- ("str_a_4", "str_b_4", "str_c_4") :: Nil
+ Row("str_a_1", null, null) ::
+ Row("str_a_2", null, null) ::
+ Row(null, "str_b_3", null) ::
+ Row("str_a_4", "str_b_4", "str_c_4") :: Nil
)
}
@@ -734,12 +735,12 @@ class JsonSuite extends QueryTest {
|SELECT a, b, c, _unparsed
|FROM jsonTable
""".stripMargin),
- (null, null, null, "{") ::
- (null, null, null, "") ::
- (null, null, null, """{"a":1, b:2}""") ::
- (null, null, null, """{"a":{, b:3}""") ::
- ("str_a_4", "str_b_4", "str_c_4", null) ::
- (null, null, null, "]") :: Nil
+ Row(null, null, null, "{") ::
+ Row(null, null, null, "") ::
+ Row(null, null, null, """{"a":1, b:2}""") ::
+ Row(null, null, null, """{"a":{, b:3}""") ::
+ Row("str_a_4", "str_b_4", "str_c_4", null) ::
+ Row(null, null, null, "]") :: Nil
)
checkAnswer(
@@ -749,7 +750,7 @@ class JsonSuite extends QueryTest {
|FROM jsonTable
|WHERE _unparsed IS NULL
""".stripMargin),
- ("str_a_4", "str_b_4", "str_c_4") :: Nil
+ Row("str_a_4", "str_b_4", "str_c_4")
)
checkAnswer(
@@ -759,11 +760,11 @@ class JsonSuite extends QueryTest {
|FROM jsonTable
|WHERE _unparsed IS NOT NULL
""".stripMargin),
- Seq("{") ::
- Seq("") ::
- Seq("""{"a":1, b:2}""") ::
- Seq("""{"a":{, b:3}""") ::
- Seq("]") :: Nil
+ Row("{") ::
+ Row("") ::
+ Row("""{"a":1, b:2}""") ::
+ Row("""{"a":{, b:3}""") ::
+ Row("]") :: Nil
)
TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
@@ -793,10 +794,10 @@ class JsonSuite extends QueryTest {
|SELECT field1, field2, field3, field4
|FROM jsonTable
""".stripMargin),
- Seq(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) ::
- Seq(null, Seq(null, Seq(Seq(1))), null, null) ::
- Seq(null, null, Seq(Seq(null), Seq(Seq("2"))), null) ::
- Seq(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil
+ Row(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) ::
+ Row(null, Seq(null, Seq(Row(1))), null, null) ::
+ Row(null, null, Seq(Seq(null), Seq(Row("2"))), null) ::
+ Row(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil
)
}
@@ -851,12 +852,12 @@ class JsonSuite extends QueryTest {
primTable.registerTempTable("primativeTable")
checkAnswer(
sql("select * from primativeTable"),
- (new java.math.BigDecimal("92233720368547758070"),
+ Row(new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
21474836470L,
- "this is a simple string.") :: Nil
+ "this is a simple string.")
)
val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1)
@@ -865,38 +866,38 @@ class JsonSuite extends QueryTest {
// Access elements of a primitive array.
checkAnswer(
sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"),
- ("str1", "str2", null) :: Nil
+ Row("str1", "str2", null)
)
// Access an array of null values.
checkAnswer(
sql("select arrayOfNull from complexTable"),
- Seq(Seq(null, null, null, null)) :: Nil
+ Row(Seq(null, null, null, null))
)
// Access elements of a BigInteger array (we use DecimalType internally).
checkAnswer(
sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"),
- (new java.math.BigDecimal("922337203685477580700"),
- new java.math.BigDecimal("-922337203685477580800"), null) :: Nil
+ Row(new java.math.BigDecimal("922337203685477580700"),
+ new java.math.BigDecimal("-922337203685477580800"), null)
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"),
- (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil
+ Row(Seq("1", "2", "3"), Seq("str1", "str2"))
)
// Access elements of an array of arrays.
checkAnswer(
sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"),
- (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil
+ Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1))
)
// Access elements of an array inside a field of type ArrayType(ArrayType).
checkAnswer(
sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"),
- ("str2", 2.1) :: Nil
+ Row("str2", 2.1)
)
// Access a struct and fields inside of it.
@@ -911,13 +912,13 @@ class JsonSuite extends QueryTest {
// Access an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"),
- (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil
+ Row(Seq(4, 5, 6), Seq("str1", "str2"))
)
// Access elements of an array field of a struct.
checkAnswer(
sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"),
- (5, null) :: Nil
+ Row(5, null)
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index 4c3a04506ce42..1e7d3e06fc196 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -21,7 +21,7 @@ import parquet.filter2.predicate.Operators._
import parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal, Predicate, Row}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD}
@@ -40,15 +40,16 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD}
class ParquetFilterSuite extends QueryTest with ParquetTest {
val sqlContext = TestSQLContext
- private def checkFilterPushdown(
+ private def checkFilterPredicate(
rdd: SchemaRDD,
- output: Seq[Symbol],
predicate: Predicate,
filterClass: Class[_ <: FilterPredicate],
- checker: (SchemaRDD, Any) => Unit,
- expectedResult: => Any): Unit = {
+ checker: (SchemaRDD, Seq[Row]) => Unit,
+ expected: Seq[Row]): Unit = {
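+ // Derive the projected columns from the attributes referenced by the predicate,
+ // so call sites no longer need to pass an explicit output list.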
+ val output = predicate.collect { case a: Attribute => a }.distinct
+
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") {
- val query = rdd.select(output.map(_.attr): _*).where(predicate)
+ val query = rdd.select(output: _*).where(predicate)
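+ // Collect the predicate, if any, that was actually pushed down into the Parquet scan.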
val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect {
case plan: ParquetTableScan => plan.columnPruningPred
@@ -58,202 +59,180 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
maybeAnalyzedPredicate.foreach { pred =>
val maybeFilter = ParquetFilters.createFilter(pred)
assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred")
- maybeFilter.foreach(f => assert(f.getClass === filterClass))
+ maybeFilter.foreach { f =>
+ // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`)
+ assert(f.getClass === filterClass)
+ }
}
- checker(query, expectedResult)
+ checker(query, expected)
}
}
- private def checkFilterPushdown
- (rdd: SchemaRDD, output: Symbol*)
- (predicate: Predicate, filterClass: Class[_ <: FilterPredicate])
- (expectedResult: => Any): Unit = {
- checkFilterPushdown(rdd, output, predicate, filterClass, checkAnswer _, expectedResult)
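+ // Takes the SchemaRDD as an implicit so each assertion reads as a
+ // (predicate, expected filter class, expected rows) triple.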
+ private def checkFilterPredicate
+ (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row])
+ (implicit rdd: SchemaRDD): Unit = {
+ checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected)
}
- def checkBinaryFilterPushdown
- (rdd: SchemaRDD, output: Symbol*)
- (predicate: Predicate, filterClass: Class[_ <: FilterPredicate])
- (expectedResult: => Any): Unit = {
- def checkBinaryAnswer(rdd: SchemaRDD, result: Any): Unit = {
- val actual = rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq
- val expected = result match {
- case s: Seq[_] => s.map(_.asInstanceOf[Row].getAs[Array[Byte]](0).mkString(","))
- case s => Seq(s.asInstanceOf[Array[Byte]].mkString(","))
- }
- assert(actual.sorted === expected.sorted)
- }
- checkFilterPushdown(rdd, output, predicate, filterClass, checkBinaryAnswer _, expectedResult)
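+ // Convenience overload that wraps a single expected value into Seq(Row(expected)).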
+ private def checkFilterPredicate[T]
+ (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T)
+ (implicit rdd: SchemaRDD): Unit = {
+ checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd)
}
test("filter pushdown - boolean") {
- withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) {
- Seq(Row(true), Row(false))
- }
+ withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd =>
+ checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false)))
- checkFilterPushdown(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(true)
- checkFilterPushdown(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]]) {
- false
- }
+ checkFilterPredicate('_1 === true, classOf[Eq [_]], true)
+ checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false)
}
}
test("filter pushdown - integer") {
- withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) {
- (1 to 4).map(Row.apply(_))
- }
-
- checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[Integer]])(1)
- checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) {
- (2 to 4).map(Row.apply(_))
- }
-
- checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [Integer]])(1)
- checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [Integer]])(4)
- checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[Integer]])(1)
- checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[Integer]])(4)
-
- checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [Integer]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [Integer]])(4)
- checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[Integer]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[Integer]])(4)
-
- checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[Integer]])(4)
- checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3)
- checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
- Seq(Row(1), Row(4))
- }
+ withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd =>
+ checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1)
+ checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
+ checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
}
test("filter pushdown - long") {
- withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) {
- (1 to 4).map(Row.apply(_))
- }
-
- checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Long]])(1)
- checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) {
- (2 to 4).map(Row.apply(_))
- }
-
- checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Long]])(1)
- checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Long]])(4)
- checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Long]])(1)
- checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Long]])(4)
-
- checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Long]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Long]])(4)
- checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Long]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Long]])(4)
-
- checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Long]])(4)
- checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3)
- checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
- Seq(Row(1), Row(4))
- }
+ withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd =>
+ checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
+ checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
+ checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
}
test("filter pushdown - float") {
- withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) {
- (1 to 4).map(Row.apply(_))
- }
-
- checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Float]])(1)
- checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) {
- (2 to 4).map(Row.apply(_))
- }
-
- checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Float]])(1)
- checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Float]])(4)
- checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Float]])(1)
- checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Float]])(4)
-
- checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Float]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Float]])(4)
- checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Float]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Float]])(4)
-
- checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Float]])(4)
- checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3)
- checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
- Seq(Row(1), Row(4))
- }
+ withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd =>
+ checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1)
+ checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
+ checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
}
test("filter pushdown - double") {
- withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) {
- (1 to 4).map(Row.apply(_))
- }
-
- checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Double]])(1)
- checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) {
- (2 to 4).map(Row.apply(_))
- }
-
- checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Double]])(1)
- checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Double]])(4)
- checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Double]])(1)
- checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Double]])(4)
-
- checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq[Integer]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Double]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Double]])(4)
- checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Double]])(1)
- checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Double]])(4)
-
- checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Double]])(4)
- checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3)
- checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) {
- Seq(Row(1), Row(4))
- }
+ withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd =>
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1)
+ checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4)
+ checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
+ checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
}
test("filter pushdown - string") {
- withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { rdd =>
- checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row])
- checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) {
- (1 to 4).map(i => Row.apply(i.toString))
- }
-
- checkFilterPushdown(rdd, '_1)('_1 === "1", classOf[Eq[String]])("1")
- checkFilterPushdown(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) {
- (2 to 4).map(i => Row.apply(i.toString))
- }
+ withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd =>
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(
+ '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString)))
+
+ checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1")
+ checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString)))
+
+ checkFilterPredicate('_1 < "2", classOf[Lt [_]], "1")
+ checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4")
+ checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1")
+ checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4")
+
+ checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1")
+ checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1")
+ checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4")
+ checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1")
+ checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4")
+
+ checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4")
+ checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3")
+ checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4")))
+ }
+ }
- checkFilterPushdown(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])("4")
- checkFilterPushdown(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])("4")
-
- checkFilterPushdown(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])("4")
- checkFilterPushdown(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])("1")
- checkFilterPushdown(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])("4")
-
- checkFilterPushdown(rdd, '_1)(!('_1 < "4"), classOf[Operators.GtEq[java.lang.String]])("4")
- checkFilterPushdown(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])("3")
- checkFilterPushdown(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) {
- Seq(Row("1"), Row("4"))
+ def checkBinaryFilterPredicate
+ (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row])
+ (implicit rdd: SchemaRDD): Unit = {
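+ // Array[Byte] does not define structural equality, so compare the sorted,
+ // comma-joined renderings of the byte arrays instead.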
+ def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = {
+ assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) {
+ rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
}
}
+
+ checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected)
+ }
+
+ def checkBinaryFilterPredicate
+ (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte])
+ (implicit rdd: SchemaRDD): Unit = {
+ checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd)
}
test("filter pushdown - binary") {
@@ -261,33 +240,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
def b: Array[Byte] = int.toString.getBytes("UTF-8")
}
- withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { rdd =>
- checkBinaryFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row])
- checkBinaryFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) {
- (1 to 4).map(i => Row.apply(i.b)).toSeq
- }
-
- checkBinaryFilterPushdown(rdd, '_1)('_1 === 1.b, classOf[Eq[Array[Byte]]])(1.b)
- checkBinaryFilterPushdown(rdd, '_1)('_1 !== 1.b, classOf[Operators.NotEq[Array[Byte]]]) {
- (2 to 4).map(i => Row.apply(i.b)).toSeq
- }
-
- checkBinaryFilterPushdown(rdd, '_1)('_1 < 2.b, classOf[Lt [Array[Byte]]])(1.b)
- checkBinaryFilterPushdown(rdd, '_1)('_1 > 3.b, classOf[Gt [Array[Byte]]])(4.b)
- checkBinaryFilterPushdown(rdd, '_1)('_1 <= 1.b, classOf[LtEq[Array[Byte]]])(1.b)
- checkBinaryFilterPushdown(rdd, '_1)('_1 >= 4.b, classOf[GtEq[Array[Byte]]])(4.b)
-
- checkBinaryFilterPushdown(rdd, '_1)(Literal(1.b) === '_1, classOf[Eq [Array[Byte]]])(1.b)
- checkBinaryFilterPushdown(rdd, '_1)(Literal(2.b) > '_1, classOf[Lt [Array[Byte]]])(1.b)
- checkBinaryFilterPushdown(rdd, '_1)(Literal(3.b) < '_1, classOf[Gt [Array[Byte]]])(4.b)
- checkBinaryFilterPushdown(rdd, '_1)(Literal(1.b) >= '_1, classOf[LtEq[Array[Byte]]])(1.b)
- checkBinaryFilterPushdown(rdd, '_1)(Literal(4.b) <= '_1, classOf[GtEq[Array[Byte]]])(4.b)
-
- checkBinaryFilterPushdown(rdd, '_1)(!('_1 < 4.b), classOf[Operators.GtEq[Array[Byte]]])(4.b)
- checkBinaryFilterPushdown(rdd, '_1)('_1 > 2.b && '_1 < 4.b, classOf[Operators.And])(3.b)
- checkBinaryFilterPushdown(rdd, '_1)('_1 < 2.b || '_1 > 3.b, classOf[Operators.Or]) {
- Seq(Row(1.b), Row(4.b))
- }
+ withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd =>
+ checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkBinaryFilterPredicate(
+ '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq)
+
+ checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq [_]], 1.b)
+ checkBinaryFilterPredicate(
+ '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq)
+
+ checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b)
+ checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b)
+ checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b)
+ checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b)
+
+ checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b)
+ checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b)
+ checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b)
+ checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b)
+
+ checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b)
+ checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b)
+ checkBinaryFilterPredicate(
+ '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b)))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index 973819aaa4d77..a57e4e85a35ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -68,8 +68,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
/**
* Writes `data` to a Parquet file, reads it back, and checks the file contents.
*/
- protected def checkParquetFile[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
- withParquetRDD(data)(checkAnswer(_, data))
+ protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = {
+ withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple)))
}
test("basic data types (without binary)") {
@@ -143,7 +143,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
withParquetRDD(data) { rdd =>
// Structs are converted to `Row`s
checkAnswer(rdd, data.map { case Tuple1(struct) =>
- Tuple1(Row(struct.productIterator.toSeq: _*))
+ Row(Row(struct.productIterator.toSeq: _*))
})
}
}
@@ -153,7 +153,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
withParquetRDD(data) { rdd =>
// Structs are converted to `Row`s
checkAnswer(rdd, data.map { case Tuple1(struct) =>
- Tuple1(Row(struct.productIterator.toSeq: _*))
+ Row(Row(struct.productIterator.toSeq: _*))
})
}
}
@@ -162,7 +162,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i"))))
withParquetRDD(data) { rdd =>
checkAnswer(rdd, data.map { case Tuple1(m) =>
- Tuple1(m.mapValues(struct => Row(struct.productIterator.toSeq: _*)))
+ Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*)))
})
}
}
@@ -261,7 +261,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
val path = new Path(dir.toURI.toString, "part-r-0.parquet")
makeRawParquetFile(path)
checkAnswer(parquetFile(path.toString), (0 until 10).map { i =>
- (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
+ Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
})
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 3a073a6b7057e..1263ff818ea19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -17,1030 +17,72 @@
package org.apache.spark.sql.parquet
-import scala.reflect.ClassTag
-
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.mapreduce.Job
-import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
-import parquet.filter2.predicate.{FilterPredicate, Operators}
-import parquet.hadoop.ParquetFileWriter
-import parquet.hadoop.util.ContextUtil
-import parquet.io.api.Binary
-
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
-
-case class TestRDDEntry(key: Int, value: String)
-
-case class NullReflectData(
- intField: java.lang.Integer,
- longField: java.lang.Long,
- floatField: java.lang.Float,
- doubleField: java.lang.Double,
- booleanField: java.lang.Boolean)
-
-case class OptionalReflectData(
- intField: Option[Int],
- longField: Option[Long],
- floatField: Option[Float],
- doubleField: Option[Double],
- booleanField: Option[Boolean])
-
-case class Nested(i: Int, s: String)
-
-case class Data(array: Seq[Int], nested: Nested)
-
-case class AllDataTypes(
- stringField: String,
- intField: Int,
- longField: Long,
- floatField: Float,
- doubleField: Double,
- shortField: Short,
- byteField: Byte,
- booleanField: Boolean)
-
-case class AllDataTypesWithNonPrimitiveType(
- stringField: String,
- intField: Int,
- longField: Long,
- floatField: Float,
- doubleField: Double,
- shortField: Short,
- byteField: Byte,
- booleanField: Boolean,
- array: Seq[Int],
- arrayContainsNull: Seq[Option[Int]],
- map: Map[Int, Long],
- mapValueContainsNull: Map[Int, Option[Long]],
- data: Data)
-
-case class BinaryData(binaryData: Array[Byte])
-
-case class NumericData(i: Int, d: Double)
-class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
- TestData // Load test data tables.
-
- private var testRDD: SchemaRDD = null
- private val originalParquetFilterPushdownEnabled = TestSQLContext.conf.parquetFilterPushDown
-
- override def beforeAll() {
- ParquetTestData.writeFile()
- ParquetTestData.writeFilterFile()
- ParquetTestData.writeNestedFile1()
- ParquetTestData.writeNestedFile2()
- ParquetTestData.writeNestedFile3()
- ParquetTestData.writeNestedFile4()
- ParquetTestData.writeGlobFiles()
- testRDD = parquetFile(ParquetTestData.testDir.toString)
- testRDD.registerTempTable("testsource")
- parquetFile(ParquetTestData.testFilterDir.toString)
- .registerTempTable("testfiltersource")
-
- setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, "true")
- }
-
- override def afterAll() {
- Utils.deleteRecursively(ParquetTestData.testDir)
- Utils.deleteRecursively(ParquetTestData.testFilterDir)
- Utils.deleteRecursively(ParquetTestData.testNestedDir1)
- Utils.deleteRecursively(ParquetTestData.testNestedDir2)
- Utils.deleteRecursively(ParquetTestData.testNestedDir3)
- Utils.deleteRecursively(ParquetTestData.testNestedDir4)
- Utils.deleteRecursively(ParquetTestData.testGlobDir)
- // here we should also unregister the table??
-
- setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, originalParquetFilterPushdownEnabled.toString)
- }
+/**
+ * A test suite for various Parquet queries.
+ */
+class ParquetQuerySuite extends QueryTest with ParquetTest {
+ val sqlContext = TestSQLContext
- test("Read/Write All Types") {
- val tempDir = getTempFilePath("parquetTest").getCanonicalPath
- val range = (0 to 255)
- val data = sparkContext.parallelize(range).map { x =>
- parquet.AllDataTypes(
- s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)
+ test("simple projection") {
+ withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
+ checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_)))
}
-
- data.saveAsParquetFile(tempDir)
-
- checkAnswer(
- parquetFile(tempDir),
- data.toSchemaRDD.collect().toSeq)
}
- test("read/write binary data") {
- // Since equality for Array[Byte] is broken we test this separately.
- val tempDir = getTempFilePath("parquetTest").getCanonicalPath
- sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir)
- parquetFile(tempDir)
- .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8"))
- .collect().toSeq == Seq("test")
- }
-
- ignore("Treat binary as string") {
- val oldIsParquetBinaryAsString = TestSQLContext.conf.isParquetBinaryAsString
-
- // Create the test file.
- val file = getTempFilePath("parquet")
- val path = file.toString
- val range = (0 to 255)
- val rowRDD = TestSQLContext.sparkContext.parallelize(range)
- .map(i => org.apache.spark.sql.Row(i, s"val_$i".getBytes))
- // We need to ask Parquet to store the String column as a Binary column.
- val schema = StructType(
- StructField("c1", IntegerType, false) ::
- StructField("c2", BinaryType, false) :: Nil)
- val schemaRDD1 = applySchema(rowRDD, schema)
- schemaRDD1.saveAsParquetFile(path)
- checkAnswer(
- parquetFile(path).select('c1, 'c2.cast(StringType)),
- schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq)
-
- setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true")
- parquetFile(path).printSchema()
- checkAnswer(
- parquetFile(path),
- schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq)
-
-
- // Set it back.
- TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString)
- }
-
- test("Compression options for writing to a Parquetfile") {
- val defaultParquetCompressionCodec = TestSQLContext.conf.parquetCompressionCodec
- import scala.collection.JavaConversions._
-
- val file = getTempFilePath("parquet")
- val path = file.toString
- val rdd = TestSQLContext.sparkContext.parallelize((1 to 100))
- .map(i => TestRDDEntry(i, s"val_$i"))
-
- // test default compression codec
- rdd.saveAsParquetFile(path)
- var actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration))
- .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct
- assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil)
-
- parquetFile(path).registerTempTable("tmp")
- checkAnswer(
- sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
-
- Utils.deleteRecursively(file)
-
- // test uncompressed parquet file with property value "UNCOMPRESSED"
- TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "UNCOMPRESSED")
-
- rdd.saveAsParquetFile(path)
- actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration))
- .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct
- assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil)
-
- parquetFile(path).registerTempTable("tmp")
- checkAnswer(
- sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
-
- Utils.deleteRecursively(file)
-
- // test uncompressed parquet file with property value "none"
- TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "none")
-
- rdd.saveAsParquetFile(path)
- actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration))
- .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct
- assert(actualCodec === "UNCOMPRESSED" :: Nil)
-
- parquetFile(path).registerTempTable("tmp")
- checkAnswer(
- sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
-
- Utils.deleteRecursively(file)
-
- // test gzip compression codec
- TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "gzip")
-
- rdd.saveAsParquetFile(path)
- actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration))
- .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct
- assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil)
-
- parquetFile(path).registerTempTable("tmp")
- checkAnswer(
- sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
-
- Utils.deleteRecursively(file)
-
- // test snappy compression codec
- TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "snappy")
-
- rdd.saveAsParquetFile(path)
- actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration))
- .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct
- assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil)
-
- parquetFile(path).registerTempTable("tmp")
- checkAnswer(
- sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
-
- Utils.deleteRecursively(file)
-
- // TODO: Lzo requires additional external setup steps so leave it out for now
- // ref.: https://github.com/Parquet/parquet-mr/blob/parquet-1.5.0/parquet-hadoop/src/test/java/parquet/hadoop/example/TestInputOutputFormat.java#L169
-
- // Set it back.
- TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, defaultParquetCompressionCodec)
- }
-
- test("Read/Write All Types with non-primitive type") {
- val tempDir = getTempFilePath("parquetTest").getCanonicalPath
- val range = (0 to 255)
- val data = sparkContext.parallelize(range).map { x =>
- parquet.AllDataTypesWithNonPrimitiveType(
- s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
- (0 until x),
- (0 until x).map(Option(_).filter(_ % 3 == 0)),
- (0 until x).map(i => i -> i.toLong).toMap,
- (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None),
- parquet.Data((0 until x), parquet.Nested(x, s"$x")))
+ test("appending") {
+ val data = (0 until 10).map(i => (i, i.toString))
+ withParquetTable(data, "t") {
+ sql("INSERT INTO t SELECT * FROM t")
+ checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
}
- data.saveAsParquetFile(tempDir)
-
- checkAnswer(
- parquetFile(tempDir),
- data.toSchemaRDD.collect().toSeq)
}
- test("self-join parquet files") {
- val x = ParquetTestData.testData.as('x)
- val y = ParquetTestData.testData.as('y)
- val query = x.join(y).where("x.myint".attr === "y.myint".attr)
-
- // Check to make sure that the attributes from either side of the join have unique expression
- // ids.
- query.queryExecution.analyzed.output.filter(_.name == "myint") match {
- case Seq(i1, i2) if(i1.exprId == i2.exprId) =>
- fail(s"Duplicate expression IDs found in query plan: $query")
- case Seq(_, _) => // All good
+ test("self-join") {
+ // 4 rows; column _1 is null in rows 2 and 4
+ val data = (1 to 4).map { i =>
+ val maybeInt = if (i % 2 == 0) None else Some(i)
+ (maybeInt, i.toString)
}
- val result = query.collect()
- assert(result.size === 9, "self-join result has incorrect size")
- assert(result(0).size === 12, "result row has incorrect size")
- result.zipWithIndex.foreach {
- case (row, index) => row.zipWithIndex.foreach {
- case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column")
- }
- }
- }
+ withParquetTable(data, "t") {
+ val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1")
+ val queryOutput = selfJoin.queryExecution.analyzed.output
- test("Import of simple Parquet file") {
- val result = parquetFile(ParquetTestData.testDir.toString).collect()
- assert(result.size === 15)
- result.zipWithIndex.foreach {
- case (row, index) => {
- val checkBoolean =
- if (index % 3 == 0)
- row(0) == true
- else
- row(0) == false
- assert(checkBoolean === true, s"boolean field value in line $index did not match")
- if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match")
- assert(row(2) === "abc", s"string field value in line $index did not match")
- assert(row(3) === (index.toLong << 33), s"long value in line $index did not match")
- assert(row(4) === 2.5F, s"float field value in line $index did not match")
- assert(row(5) === 4.5D, s"double field value in line $index did not match")
+ assertResult(4, s"Field count mismatches")(queryOutput.size)
+ assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") {
+ queryOutput.filter(_.name == "_1").map(_.exprId).size
}
- }
- }
- test("Projection of simple Parquet file") {
- val result = ParquetTestData.testData.select('myboolean, 'mylong).collect()
- result.zipWithIndex.foreach {
- case (row, index) => {
- if (index % 3 == 0)
- assert(row(0) === true, s"boolean field value in line $index did not match (every third row)")
- else
- assert(row(0) === false, s"boolean field value in line $index did not match")
- assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match")
- assert(row.size === 2, s"number of columns in projection in line $index is incorrect")
- }
+ checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3")))
}
}
- test("Writing metadata from scratch for table CREATE") {
- val job = new Job()
- val path = new Path(getTempFilePath("testtable").getCanonicalFile.toURI.toString)
- val fs: FileSystem = FileSystem.getLocal(ContextUtil.getConfiguration(job))
- ParquetTypesConverter.writeMetaData(
- ParquetTestData.testData.output,
- path,
- TestSQLContext.sparkContext.hadoopConfiguration)
- assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE)))
- val metaData = ParquetTypesConverter.readMetaData(path, Some(ContextUtil.getConfiguration(job)))
- assert(metaData != null)
- ParquetTestData
- .testData
- .parquetSchema
- .checkContains(metaData.getFileMetaData.getSchema) // throws exception if incompatible
- metaData
- .getFileMetaData
- .getSchema
- .checkContains(ParquetTestData.testData.parquetSchema) // throws exception if incompatible
- fs.delete(path, true)
- }
-
- test("Creating case class RDD table") {
- TestSQLContext.sparkContext.parallelize((1 to 100))
- .map(i => TestRDDEntry(i, s"val_$i"))
- .registerTempTable("tmp")
- val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0))
- var counter = 1
- rdd.foreach {
- // '===' does not like string comparison?
- row: Row => {
- assert(row.getString(1).equals(s"val_$counter"), s"row $counter value ${row.getString(1)} does not match val_$counter")
- counter = counter + 1
- }
+ test("nested data - struct with array field") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withParquetTable(data, "t") {
+ checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map {
+ case Tuple1((_, Seq(string))) => Row(string)
+ })
}
}
- test("Read a parquet file instead of a directory") {
- val file = getTempFilePath("parquet")
- val path = file.toString
- val fsPath = new Path(path)
- val fs: FileSystem = fsPath.getFileSystem(TestSQLContext.sparkContext.hadoopConfiguration)
- val rdd = TestSQLContext.sparkContext.parallelize((1 to 100))
- .map(i => TestRDDEntry(i, s"val_$i"))
- rdd.coalesce(1).saveAsParquetFile(path)
-
- val children = fs.listStatus(fsPath).filter(_.getPath.getName.endsWith(".parquet"))
- assert(children.length > 0)
- val readFile = parquetFile(path + "/" + children(0).getPath.getName)
- readFile.registerTempTable("tmpx")
- val rdd_copy = sql("SELECT * FROM tmpx").collect()
- val rdd_orig = rdd.collect()
- for(i <- 0 to 99) {
- assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i")
- assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i")
+ test("nested data - array of struct") {
+ val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i")))
+ withParquetTable(data, "t") {
+ checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map {
+ case Tuple1(Seq((_, string))) => Row(string)
+ })
}
- Utils.deleteRecursively(file)
- }
-
- test("Insert (appending) to same table via Scala API") {
- sql("INSERT INTO testsource SELECT * FROM testsource")
- val double_rdd = sql("SELECT * FROM testsource").collect()
- assert(double_rdd != null)
- assert(double_rdd.size === 30)
-
- // let's restore the original test data
- Utils.deleteRecursively(ParquetTestData.testDir)
- ParquetTestData.writeFile()
- }
-
- test("save and load case class RDD with nulls as parquet") {
- val data = parquet.NullReflectData(null, null, null, null, null)
- val rdd = sparkContext.parallelize(data :: Nil)
-
- val file = getTempFilePath("parquet")
- val path = file.toString
- rdd.saveAsParquetFile(path)
- val readFile = parquetFile(path)
-
- val rdd_saved = readFile.collect()
- assert(rdd_saved(0) === Seq.fill(5)(null))
- Utils.deleteRecursively(file)
- assert(true)
- }
-
- test("save and load case class RDD with Nones as parquet") {
- val data = parquet.OptionalReflectData(None, None, None, None, None)
- val rdd = sparkContext.parallelize(data :: Nil)
-
- val file = getTempFilePath("parquet")
- val path = file.toString
- rdd.saveAsParquetFile(path)
- val readFile = parquetFile(path)
-
- val rdd_saved = readFile.collect()
- assert(rdd_saved(0) === Seq.fill(5)(null))
- Utils.deleteRecursively(file)
- assert(true)
- }
-
- test("make RecordFilter for simple predicates") {
- def checkFilter[T <: FilterPredicate : ClassTag](
- predicate: Expression,
- defined: Boolean = true): Unit = {
- val filter = ParquetFilters.createFilter(predicate)
- if (defined) {
- assert(filter.isDefined)
- val tClass = implicitly[ClassTag[T]].runtimeClass
- val filterGet = filter.get
- assert(
- tClass.isInstance(filterGet),
- s"$filterGet of type ${filterGet.getClass} is not an instance of $tClass")
- } else {
- assert(filter.isEmpty)
- }
- }
-
- checkFilter[Operators.Eq[Integer]]('a.int === 1)
- checkFilter[Operators.Eq[Integer]](Literal(1) === 'a.int)
-
- checkFilter[Operators.Lt[Integer]]('a.int < 4)
- checkFilter[Operators.Lt[Integer]](Literal(4) > 'a.int)
- checkFilter[Operators.LtEq[Integer]]('a.int <= 4)
- checkFilter[Operators.LtEq[Integer]](Literal(4) >= 'a.int)
-
- checkFilter[Operators.Gt[Integer]]('a.int > 4)
- checkFilter[Operators.Gt[Integer]](Literal(4) < 'a.int)
- checkFilter[Operators.GtEq[Integer]]('a.int >= 4)
- checkFilter[Operators.GtEq[Integer]](Literal(4) <= 'a.int)
-
- checkFilter[Operators.And]('a.int === 1 && 'a.int < 4)
- checkFilter[Operators.Or]('a.int === 1 || 'a.int < 4)
- checkFilter[Operators.NotEq[Integer]](!('a.int === 1))
-
- checkFilter('a.int > 'b.int, defined = false)
- checkFilter(('a.int > 'b.int) && ('a.int > 'b.int), defined = false)
- }
-
- test("test filter by predicate pushdown") {
- for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) {
- val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100")
- assert(
- query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result1 = query1.collect()
- assert(result1.size === 50)
- assert(result1(0)(1) === 100)
- assert(result1(49)(1) === 149)
- val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200")
- assert(
- query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result2 = query2.collect()
- assert(result2.size === 50)
- if (myval == "myint" || myval == "mylong") {
- assert(result2(0)(1) === 151)
- assert(result2(49)(1) === 200)
- } else {
- assert(result2(0)(1) === 150)
- assert(result2(49)(1) === 199)
- }
- }
- for(myval <- Seq("myint", "mylong")) {
- val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10")
- assert(
- query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result3 = query3.collect()
- assert(result3.size === 20)
- assert(result3(0)(1) === 0)
- assert(result3(9)(1) === 9)
- assert(result3(10)(1) === 191)
- assert(result3(19)(1) === 200)
- }
- for(myval <- Seq("mydouble", "myfloat")) {
- val result4 =
- if (myval == "mydouble") {
- val query4 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10.0")
- assert(
- query4.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- query4.collect()
- } else {
- // CASTs are problematic. Here myfloat will be casted to a double and it seems there is
- // currently no way to specify float constants in SqlParser?
- sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect()
- }
- assert(result4.size === 20)
- assert(result4(0)(1) === 0)
- assert(result4(9)(1) === 9)
- assert(result4(10)(1) === 191)
- assert(result4(19)(1) === 200)
- }
- val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40")
- assert(
- query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val booleanResult = query5.collect()
- assert(booleanResult.size === 10)
- for(i <- 0 until 10) {
- if (!booleanResult(i).getBoolean(0)) {
- fail(s"Boolean value in result row $i not true")
- }
- if (booleanResult(i).getInt(1) != i * 4) {
- fail(s"Int value in result row $i should be ${4*i}")
- }
- }
- val query6 = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"")
- assert(
- query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val stringResult = query6.collect()
- assert(stringResult.size === 1)
- assert(stringResult(0).getString(2) == "100", "stringvalue incorrect")
- assert(stringResult(0).getInt(1) === 100)
-
- val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40")
- assert(
- query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val optResult = query7.collect()
- assert(optResult.size === 20)
- for(i <- 0 until 20) {
- if (optResult(i)(7) != i * 2) {
- fail(s"optional Int value in result row $i should be ${2*4*i}")
- }
- }
- for(myval <- Seq("myoptint", "myoptlong", "myoptdouble", "myoptfloat")) {
- val query8 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100")
- assert(
- query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result8 = query8.collect()
- assert(result8.size === 25)
- assert(result8(0)(7) === 100)
- assert(result8(24)(7) === 148)
- val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200")
- assert(
- query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result9 = query9.collect()
- assert(result9.size === 25)
- if (myval == "myoptint" || myval == "myoptlong") {
- assert(result9(0)(7) === 152)
- assert(result9(24)(7) === 200)
- } else {
- assert(result9(0)(7) === 150)
- assert(result9(24)(7) === 198)
- }
- }
- val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"")
- assert(
- query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result10 = query10.collect()
- assert(result10.size === 1)
- assert(result10(0).getString(8) == "100", "stringvalue incorrect")
- assert(result10(0).getInt(7) === 100)
- val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40")
- assert(
- query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result11 = query11.collect()
- assert(result11.size === 7)
- for(i <- 0 until 6) {
- if (!result11(i).getBoolean(6)) {
- fail(s"optional Boolean value in result row $i not true")
- }
- if (result11(i).getInt(7) != i * 6) {
- fail(s"optional Int value in result row $i should be ${6*i}")
- }
- }
-
- val query12 = sql("SELECT * FROM testfiltersource WHERE mystring >= \"50\"")
- assert(
- query12.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result12 = query12.collect()
- assert(result12.size === 54)
- assert(result12(0).getString(2) == "6")
- assert(result12(4).getString(2) == "50")
- assert(result12(53).getString(2) == "99")
-
- val query13 = sql("SELECT * FROM testfiltersource WHERE mystring > \"50\"")
- assert(
- query13.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result13 = query13.collect()
- assert(result13.size === 53)
- assert(result13(0).getString(2) == "6")
- assert(result13(4).getString(2) == "51")
- assert(result13(52).getString(2) == "99")
-
- val query14 = sql("SELECT * FROM testfiltersource WHERE mystring <= \"50\"")
- assert(
- query14.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result14 = query14.collect()
- assert(result14.size === 148)
- assert(result14(0).getString(2) == "0")
- assert(result14(46).getString(2) == "50")
- assert(result14(147).getString(2) == "200")
-
- val query15 = sql("SELECT * FROM testfiltersource WHERE mystring < \"50\"")
- assert(
- query15.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
- "Top operator should be ParquetTableScan after pushdown")
- val result15 = query15.collect()
- assert(result15.size === 147)
- assert(result15(0).getString(2) == "0")
- assert(result15(46).getString(2) == "100")
- assert(result15(146).getString(2) == "200")
}
test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") {
- val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10")
- assert(query.collect().size === 10)
- }
-
- test("Importing nested Parquet file (Addressbook)") {
- val result = TestSQLContext
- .parquetFile(ParquetTestData.testNestedDir1.toString)
- .toSchemaRDD
- .collect()
- assert(result != null)
- assert(result.size === 2)
- val first_record = result(0)
- val second_record = result(1)
- assert(first_record != null)
- assert(second_record != null)
- assert(first_record.size === 3)
- assert(second_record(1) === null)
- assert(second_record(2) === null)
- assert(second_record(0) === "A. Nonymous")
- assert(first_record(0) === "Julien Le Dem")
- val first_owner_numbers = first_record(1)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]]
- val first_contacts = first_record(2)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]]
- assert(first_owner_numbers != null)
- assert(first_owner_numbers(0) === "555 123 4567")
- assert(first_owner_numbers(2) === "XXX XXX XXXX")
- assert(first_contacts(0)
- .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2)
- val first_contacts_entry_one = first_contacts(0)
- .asInstanceOf[CatalystConverter.StructScalaType[_]]
- assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy")
- assert(first_contacts_entry_one(1) === "555 987 6543")
- val first_contacts_entry_two = first_contacts(1)
- .asInstanceOf[CatalystConverter.StructScalaType[_]]
- assert(first_contacts_entry_two(0) === "Chris Aniszczyk")
- }
-
- test("Importing nested Parquet file (nested numbers)") {
- val result = TestSQLContext
- .parquetFile(ParquetTestData.testNestedDir2.toString)
- .toSchemaRDD
- .collect()
- assert(result.size === 1, "number of top-level rows incorrect")
- assert(result(0).size === 5, "number of fields in row incorrect")
- assert(result(0)(0) === 1)
- assert(result(0)(1) === 7)
- val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]]
- assert(subresult1.size === 3)
- assert(subresult1(0) === (1.toLong << 32))
- assert(subresult1(1) === (1.toLong << 33))
- assert(subresult1(2) === (1.toLong << 34))
- val subresult2 = result(0)(3)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)
- .asInstanceOf[CatalystConverter.StructScalaType[_]]
- assert(subresult2.size === 2)
- assert(subresult2(0) === 2.5)
- assert(subresult2(1) === false)
- val subresult3 = result(0)(4)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]]
- assert(subresult3.size === 2)
- assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2)
- val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]]
- assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7)
- assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8)
- assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1)
- assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9)
- }
-
- test("Simple query on addressbook") {
- val data = TestSQLContext
- .parquetFile(ParquetTestData.testNestedDir1.toString)
- .toSchemaRDD
- val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect()
- assert(tmp.size === 1)
- assert(tmp(0)(0) === "Julien Le Dem")
- }
-
- test("Projection in addressbook") {
- val data = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD
- data.registerTempTable("data")
- val query = sql("SELECT owner, contacts[1].name FROM data")
- val tmp = query.collect()
- assert(tmp.size === 2)
- assert(tmp(0).size === 2)
- assert(tmp(0)(0) === "Julien Le Dem")
- assert(tmp(0)(1) === "Chris Aniszczyk")
- assert(tmp(1)(0) === "A. Nonymous")
- assert(tmp(1)(1) === null)
- }
-
- test("Simple query on nested int data") {
- val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD
- data.registerTempTable("data")
- val result1 = sql("SELECT entries[0].value FROM data").collect()
- assert(result1.size === 1)
- assert(result1(0).size === 1)
- assert(result1(0)(0) === 2.5)
- val result2 = sql("SELECT entries[0] FROM data").collect()
- assert(result2.size === 1)
- val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]]
- assert(subresult1.size === 2)
- assert(subresult1(0) === 2.5)
- assert(subresult1(1) === false)
- val result3 = sql("SELECT outerouter FROM data").collect()
- val subresult2 = result3(0)(0)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]]
- assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7)
- assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8)
- assert(result3(0)(0)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)
- .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9)
- }
-
- test("nested structs") {
- val data = parquetFile(ParquetTestData.testNestedDir3.toString)
- .toSchemaRDD
- data.registerTempTable("data")
- val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect()
- assert(result1.size === 1)
- assert(result1(0).size === 1)
- assert(result1(0)(0) === false)
- val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect()
- assert(result2.size === 1)
- assert(result2(0).size === 1)
- assert(result2(0)(0) === true)
- val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect()
- assert(result3.size === 1)
- assert(result3(0).size === 1)
- assert(result3(0)(0) === false)
- }
-
- test("simple map") {
- val data = TestSQLContext
- .parquetFile(ParquetTestData.testNestedDir4.toString)
- .toSchemaRDD
- data.registerTempTable("mapTable")
- val result1 = sql("SELECT data1 FROM mapTable").collect()
- assert(result1.size === 1)
- assert(result1(0)(0)
- .asInstanceOf[CatalystConverter.MapScalaType[String, _]]
- .getOrElse("key1", 0) === 1)
- assert(result1(0)(0)
- .asInstanceOf[CatalystConverter.MapScalaType[String, _]]
- .getOrElse("key2", 0) === 2)
- val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect()
- assert(result2(0)(0) === 1)
- }
-
- test("map with struct values") {
- val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD
- data.registerTempTable("mapTable")
- val result1 = sql("SELECT data2 FROM mapTable").collect()
- assert(result1.size === 1)
- val entry1 = result1(0)(0)
- .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
- .getOrElse("seven", null)
- assert(entry1 != null)
- assert(entry1(0) === 42)
- assert(entry1(1) === "the answer")
- val entry2 = result1(0)(0)
- .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
- .getOrElse("eight", null)
- assert(entry2 != null)
- assert(entry2(0) === 49)
- assert(entry2(1) === null)
- val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect()
- assert(result2.size === 1)
- assert(result2(0)(0) === 42.toLong)
- assert(result2(0)(1) === "the answer")
- }
-
- test("Writing out Addressbook and reading it back in") {
- // TODO: find out why CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME
- // has no effect in this test case
- val tmpdir = Utils.createTempDir()
- Utils.deleteRecursively(tmpdir)
- val result = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD
- result.saveAsParquetFile(tmpdir.toString)
- parquetFile(tmpdir.toString)
- .toSchemaRDD
- .registerTempTable("tmpcopy")
- val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect()
- assert(tmpdata.size === 2)
- assert(tmpdata(0).size === 2)
- assert(tmpdata(0)(0) === "Julien Le Dem")
- assert(tmpdata(0)(1) === "Chris Aniszczyk")
- assert(tmpdata(1)(0) === "A. Nonymous")
- assert(tmpdata(1)(1) === null)
- Utils.deleteRecursively(tmpdir)
- }
-
- test("Writing out Map and reading it back in") {
- val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD
- val tmpdir = Utils.createTempDir()
- Utils.deleteRecursively(tmpdir)
- data.saveAsParquetFile(tmpdir.toString)
- parquetFile(tmpdir.toString)
- .toSchemaRDD
- .registerTempTable("tmpmapcopy")
- val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect()
- assert(result1.size === 1)
- assert(result1(0)(0) === 2)
- val result2 = sql("SELECT data2 FROM tmpmapcopy").collect()
- assert(result2.size === 1)
- val entry1 = result2(0)(0)
- .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
- .getOrElse("seven", null)
- assert(entry1 != null)
- assert(entry1(0) === 42)
- assert(entry1(1) === "the answer")
- val entry2 = result2(0)(0)
- .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
- .getOrElse("eight", null)
- assert(entry2 != null)
- assert(entry2(0) === 49)
- assert(entry2(1) === null)
- val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect()
- assert(result3.size === 1)
- assert(result3(0)(0) === 42.toLong)
- assert(result3(0)(1) === "the answer")
- Utils.deleteRecursively(tmpdir)
- }
-
- test("read/write fixed-length decimals") {
- for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
- val tempDir = getTempFilePath("parquetTest").getCanonicalPath
- val data = sparkContext.parallelize(0 to 1000)
- .map(i => NumericData(i, i / 100.0))
- .select('i, 'd cast DecimalType(precision, scale))
- data.saveAsParquetFile(tempDir)
- checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
- }
-
- // Decimals with precision above 18 are not yet supported
- intercept[RuntimeException] {
- val tempDir = getTempFilePath("parquetTest").getCanonicalPath
- val data = sparkContext.parallelize(0 to 1000)
- .map(i => NumericData(i, i / 100.0))
- .select('i, 'd cast DecimalType(19, 10))
- data.saveAsParquetFile(tempDir)
- checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
- }
-
- // Unlimited-length decimals are not yet supported
- intercept[RuntimeException] {
- val tempDir = getTempFilePath("parquetTest").getCanonicalPath
- val data = sparkContext.parallelize(0 to 1000)
- .map(i => NumericData(i, i / 100.0))
- .select('i, 'd cast DecimalType.Unlimited)
- data.saveAsParquetFile(tempDir)
- checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
- }
- }
-
- def checkFilter(predicate: Predicate, filterClass: Class[_ <: FilterPredicate]): Unit = {
- val filter = ParquetFilters.createFilter(predicate)
- assert(filter.isDefined)
- assert(filter.get.getClass == filterClass)
- }
-
- test("Pushdown IsNull predicate") {
- checkFilter('a.int.isNull, classOf[Operators.Eq[Integer]])
- checkFilter('a.long.isNull, classOf[Operators.Eq[java.lang.Long]])
- checkFilter('a.float.isNull, classOf[Operators.Eq[java.lang.Float]])
- checkFilter('a.double.isNull, classOf[Operators.Eq[java.lang.Double]])
- checkFilter('a.string.isNull, classOf[Operators.Eq[Binary]])
- checkFilter('a.binary.isNull, classOf[Operators.Eq[Binary]])
- }
-
- test("Pushdown IsNotNull predicate") {
- checkFilter('a.int.isNotNull, classOf[Operators.NotEq[Integer]])
- checkFilter('a.long.isNotNull, classOf[Operators.NotEq[java.lang.Long]])
- checkFilter('a.float.isNotNull, classOf[Operators.NotEq[java.lang.Float]])
- checkFilter('a.double.isNotNull, classOf[Operators.NotEq[java.lang.Double]])
- checkFilter('a.string.isNotNull, classOf[Operators.NotEq[Binary]])
- checkFilter('a.binary.isNotNull, classOf[Operators.NotEq[Binary]])
- }
-
- test("Pushdown EqualTo predicate") {
- checkFilter('a.int === 0, classOf[Operators.Eq[Integer]])
- checkFilter('a.long === 0.toLong, classOf[Operators.Eq[java.lang.Long]])
- checkFilter('a.float === 0.toFloat, classOf[Operators.Eq[java.lang.Float]])
- checkFilter('a.double === 0.toDouble, classOf[Operators.Eq[java.lang.Double]])
- checkFilter('a.string === "foo", classOf[Operators.Eq[Binary]])
- checkFilter('a.binary === "foo".getBytes, classOf[Operators.Eq[Binary]])
- }
-
- test("Pushdown Not(EqualTo) predicate") {
- checkFilter(!('a.int === 0), classOf[Operators.NotEq[Integer]])
- checkFilter(!('a.long === 0.toLong), classOf[Operators.NotEq[java.lang.Long]])
- checkFilter(!('a.float === 0.toFloat), classOf[Operators.NotEq[java.lang.Float]])
- checkFilter(!('a.double === 0.toDouble), classOf[Operators.NotEq[java.lang.Double]])
- checkFilter(!('a.string === "foo"), classOf[Operators.NotEq[Binary]])
- checkFilter(!('a.binary === "foo".getBytes), classOf[Operators.NotEq[Binary]])
- }
-
- test("Pushdown LessThan predicate") {
- checkFilter('a.int < 0, classOf[Operators.Lt[Integer]])
- checkFilter('a.long < 0.toLong, classOf[Operators.Lt[java.lang.Long]])
- checkFilter('a.float < 0.toFloat, classOf[Operators.Lt[java.lang.Float]])
- checkFilter('a.double < 0.toDouble, classOf[Operators.Lt[java.lang.Double]])
- checkFilter('a.string < "foo", classOf[Operators.Lt[Binary]])
- checkFilter('a.binary < "foo".getBytes, classOf[Operators.Lt[Binary]])
- }
-
- test("Pushdown LessThanOrEqual predicate") {
- checkFilter('a.int <= 0, classOf[Operators.LtEq[Integer]])
- checkFilter('a.long <= 0.toLong, classOf[Operators.LtEq[java.lang.Long]])
- checkFilter('a.float <= 0.toFloat, classOf[Operators.LtEq[java.lang.Float]])
- checkFilter('a.double <= 0.toDouble, classOf[Operators.LtEq[java.lang.Double]])
- checkFilter('a.string <= "foo", classOf[Operators.LtEq[Binary]])
- checkFilter('a.binary <= "foo".getBytes, classOf[Operators.LtEq[Binary]])
- }
-
- test("Pushdown GreaterThan predicate") {
- checkFilter('a.int > 0, classOf[Operators.Gt[Integer]])
- checkFilter('a.long > 0.toLong, classOf[Operators.Gt[java.lang.Long]])
- checkFilter('a.float > 0.toFloat, classOf[Operators.Gt[java.lang.Float]])
- checkFilter('a.double > 0.toDouble, classOf[Operators.Gt[java.lang.Double]])
- checkFilter('a.string > "foo", classOf[Operators.Gt[Binary]])
- checkFilter('a.binary > "foo".getBytes, classOf[Operators.Gt[Binary]])
- }
-
- test("Pushdown GreaterThanOrEqual predicate") {
- checkFilter('a.int >= 0, classOf[Operators.GtEq[Integer]])
- checkFilter('a.long >= 0.toLong, classOf[Operators.GtEq[java.lang.Long]])
- checkFilter('a.float >= 0.toFloat, classOf[Operators.GtEq[java.lang.Float]])
- checkFilter('a.double >= 0.toDouble, classOf[Operators.GtEq[java.lang.Double]])
- checkFilter('a.string >= "foo", classOf[Operators.GtEq[Binary]])
- checkFilter('a.binary >= "foo".getBytes, classOf[Operators.GtEq[Binary]])
- }
-
- test("Comparison with null should not be pushed down") {
- val predicates = Seq(
- 'a.int === null,
- !('a.int === null),
-
- Literal(null) === 'a.int,
- !(Literal(null) === 'a.int),
-
- 'a.int < null,
- 'a.int <= null,
- 'a.int > null,
- 'a.int >= null,
-
- Literal(null) < 'a.int,
- Literal(null) <= 'a.int,
- Literal(null) > 'a.int,
- Literal(null) >= 'a.int
- )
-
- predicates.foreach { p =>
- assert(
- ParquetFilters.createFilter(p).isEmpty,
- "Comparison predicate with null shouldn't be pushed down")
- }
- }
-
- test("Import of simple Parquet files using glob wildcard pattern") {
- val testGlobDir = ParquetTestData.testGlobDir.toString
- val globPatterns = Array(testGlobDir + "/*/*", testGlobDir + "/spark-*/*", testGlobDir + "/?pa?k-*/*")
- globPatterns.foreach { path =>
- val result = parquetFile(path).collect()
- assert(result.size === 45)
- result.zipWithIndex.foreach {
- case (row, index) => {
- val checkBoolean =
- if ((index % 15) % 3 == 0)
- row(0) == true
- else
- row(0) == false
- assert(checkBoolean === true, s"boolean field value in line $index did not match")
- if ((index % 15) % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match")
- assert(row(2) === "abc", s"string field value in line $index did not match")
- assert(row(3) === ((index.toLong % 15) << 33), s"long value in line $index did not match")
- assert(row(4) === 2.5F, s"float field value in line $index did not match")
- assert(row(5) === 4.5D, s"double field value in line $index did not match")
- }
- }
+ withParquetTable((1 to 10).map(Tuple1.apply), "t") {
+ checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_)))
}
}
}
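
The replacement test above (like the ParquetQuerySuite2 file deleted next) leans on the `withParquetTable` fixture, a loan-pattern helper that registers a temporary Parquet table, runs the test body, and cleans up afterwards. A minimal script-style sketch of that pattern, with a hypothetical `withTempResource` helper and no Spark dependency:

    import java.nio.file.Files

    // Loan pattern: acquire a resource, run the body, and guarantee cleanup
    // even when the body throws (hypothetical helper, not Spark's).
    def withTempResource[R, T](acquire: => R)(release: R => Unit)(body: R => T): T = {
      val resource = acquire
      try body(resource) finally release(resource)
    }

    // Usage: a temporary directory that is always deleted afterwards.
    val existedDuringBody = withTempResource(Files.createTempDirectory("parquet-test")) {
      dir => Files.deleteIfExists(dir)
    } { dir =>
      Files.exists(dir)
    }
    assert(existedDuringBody)
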
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala
deleted file mode 100644
index 4c081fb4510b2..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.parquet
-
-import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._
-
-/**
- * A test suite that tests various Parquet queries.
- */
-class ParquetQuerySuite2 extends QueryTest with ParquetTest {
- val sqlContext = TestSQLContext
-
- test("simple projection") {
- withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
- checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_)))
- }
- }
-
- test("appending") {
- val data = (0 until 10).map(i => (i, i.toString))
- withParquetTable(data, "t") {
- sql("INSERT INTO t SELECT * FROM t")
- checkAnswer(table("t"), data ++ data)
- }
- }
-
- test("self-join") {
- // 4 rows, cells of column 1 of row 2 and row 4 are null
- val data = (1 to 4).map { i =>
- val maybeInt = if (i % 2 == 0) None else Some(i)
- (maybeInt, i.toString)
- }
-
- withParquetTable(data, "t") {
- val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1")
- val queryOutput = selfJoin.queryExecution.analyzed.output
-
- assertResult(4, s"Field count mismatches")(queryOutput.size)
- assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") {
- queryOutput.filter(_.name == "_1").map(_.exprId).size
- }
-
- checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3")))
- }
- }
-
- test("nested data - struct with array field") {
- val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
- withParquetTable(data, "t") {
- checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map {
- case Tuple1((_, Seq(string))) => Row(string)
- })
- }
- }
-
- test("nested data - array of struct") {
- val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i")))
- withParquetTable(data, "t") {
- checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map {
- case Tuple1(Seq((_, string))) => Row(string)
- })
- }
- }
-
- test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") {
- withParquetTable((1 to 10).map(Tuple1.apply), "t") {
- checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_)))
- }
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 264f6d94c4ed9..b1e0919b7aed1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -244,7 +244,7 @@ class TableScanSuite extends DataSourceTest {
sqlTest(
"SELECT count(*) FROM tableWithSchema",
- 10)
+ Seq(Row(10)))
sqlTest(
"SELECT `string$%Field` FROM tableWithSchema",
@@ -260,7 +260,7 @@ class TableScanSuite extends DataSourceTest {
sqlTest(
"SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where int_Field=1",
- Seq(Seq(1, 2)))
+ Seq(Row(1, 2)))
sqlTest(
"SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
index ebf7003ff9e57..3f20c6142e59a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
@@ -20,30 +20,20 @@ package org.apache.spark.sql.hive
import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical}
+import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand}
/**
* A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions.
*/
private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
- protected implicit def asParser(k: Keyword): Parser[String] =
- lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
-
+ // Keyword is a convention of AbstractSparkSQLParser, which scans all of the `Keyword`
+ // properties of this class via reflection at runtime to construct the SqlLexical object.
protected val ADD = Keyword("ADD")
protected val DFS = Keyword("DFS")
protected val FILE = Keyword("FILE")
protected val JAR = Keyword("JAR")
- private val reservedWords =
- this
- .getClass
- .getMethods
- .filter(_.getReturnType == classOf[Keyword])
- .map(_.invoke(this).asInstanceOf[Keyword].str)
-
- override val lexical = new SqlLexical(reservedWords)
-
protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl
protected lazy val hiveQl: Parser[LogicalPlan] =
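
The comment above describes the same reflection trick that the deleted lines performed locally. A standalone sketch of the convention, assuming the scala-parser-combinators `StandardTokenParsers` rather than Spark's actual `AbstractSparkSQLParser`:

    import scala.util.parsing.combinator.syntactical.StandardTokenParsers

    case class Keyword(str: String)

    abstract class ReflectiveKeywordParser extends StandardTokenParsers {
      // Scala vals compile to zero-argument accessor methods, so filtering
      // getMethods by return type finds every Keyword a subclass declares.
      protected lazy val reservedWords: Seq[String] =
        getClass.getMethods
          .filter(_.getReturnType == classOf[Keyword])
          .map(_.invoke(this).asInstanceOf[Keyword].str)

      // Deferred until first use so the subclass constructor has already
      // initialized its Keyword vals when reflection reads them.
      protected def registerKeywords(): Unit = lexical.reserved ++= reservedWords
    }

A subclass then only writes declarations such as `protected val ADD = Keyword("ADD")` and the base class discovers them.
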
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 10833c113216a..9d2cfd8e0d669 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.processors._
+import org.apache.hadoop.hive.ql.parse.VariableSubstitution
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable}
@@ -66,11 +67,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
new this.QueryExecution { val logical = plan }
override def sql(sqlText: String): SchemaRDD = {
+ val substituted = new VariableSubstitution().substitute(hiveconf, sqlText)
// TODO: Create a framework for registering parsers instead of just hardcoding if statements.
if (conf.dialect == "sql") {
- super.sql(sqlText)
+ super.sql(substituted)
} else if (conf.dialect == "hiveql") {
- new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText)))
+ new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted)))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
}
@@ -368,10 +370,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
.mkString("\t")
}
case command: ExecutedCommand =>
- command.executeCollect().map(_.head.toString)
+ command.executeCollect().map(_(0).toString)
case other =>
- val result: Seq[Seq[Any]] = other.executeCollect().toSeq
+ val result: Seq[Seq[Any]] = other.executeCollect().map(_.toSeq).toSeq
// We need the types so we can output struct field names
val types = analyzed.output.map(_.dataType)
// Reformat to match hive tab delimited output.
@@ -395,7 +397,7 @@ private object HiveContext {
protected[sql] def toHiveString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
- struct.zip(fields).map {
+ struct.toSeq.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
}.mkString("{", ",", "}")
case (seq: Seq[_], ArrayType(typ, _)) =>
@@ -418,7 +420,7 @@ private object HiveContext {
/** Hive outputs fields of structs slightly differently than top level attributes. */
protected def toHiveStructString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
- struct.zip(fields).map {
+ struct.toSeq.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
}.mkString("{", ",", "}")
case (seq: Seq[_], ArrayType(typ, _)) =>
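
The hunk above routes every query through Hive's VariableSubstitution before either parser sees it. An illustrative, script-style use of that API, assuming Hive 0.13.x on the classpath (matching the import added above):

    import org.apache.hadoop.hive.conf.HiveConf
    import org.apache.hadoop.hive.ql.parse.VariableSubstitution

    val hiveconf = new HiveConf()
    hiveconf.set("tbl", "src")

    // ${hiveconf:...} references are expanded against the configuration
    // before the text reaches the SQL or HiveQL parser.
    val substituted = new VariableSubstitution().substitute(
      hiveconf, "SELECT key FROM ${hiveconf:tbl} LIMIT 1")
    assert(substituted == "SELECT key FROM src LIMIT 1")

The "command substitution" test added to SQLQuerySuite further down exercises the same path end to end, including toggling `hive.variable.substitute`.
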
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index eeabfdd857916..82dba99900df9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -348,7 +348,7 @@ private[hive] trait HiveInspectors {
(o: Any) => {
if (o != null) {
val struct = soi.create()
- (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
+ (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach {
(field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
}
struct
@@ -432,7 +432,7 @@ private[hive] trait HiveInspectors {
}
case x: SettableStructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
- val row = a.asInstanceOf[Seq[_]]
+ val row = a.asInstanceOf[Row]
// 1. create the pojo (most likely) object
val result = x.create()
var i = 0
@@ -448,7 +448,7 @@ private[hive] trait HiveInspectors {
result
case x: StructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
- val row = a.asInstanceOf[Seq[_]]
+ val row = a.asInstanceOf[Row]
val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
var i = 0
while (i < fieldRefs.length) {
@@ -475,7 +475,7 @@ private[hive] trait HiveInspectors {
}
def wrap(
- row: Seq[Any],
+ row: Row,
inspectors: Seq[ObjectInspector],
cache: Array[AnyRef]): Array[AnyRef] = {
var i = 0
@@ -486,6 +486,18 @@ private[hive] trait HiveInspectors {
cache
}
+ def wrap(
+ row: Seq[Any],
+ inspectors: Seq[ObjectInspector],
+ cache: Array[AnyRef]): Array[AnyRef] = {
+ var i = 0
+ while (i < inspectors.length) {
+ cache(i) = wrap(row(i), inspectors(i))
+ i += 1
+ }
+ cache
+ }
+
/**
* @param dataType Catalyst data type
* @return Hive java object inspector (recursively), not the Writable ObjectInspector
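
The new `wrap(row: Seq[Any], ...)` overload mirrors the `Row` variant above; consistent with the many `.toSeq` conversions elsewhere in this patch, `Row` evidently no longer extends `Seq[Any]`, so the two parameter types must resolve to separate methods. A hedged sketch of the shared loop, with `Inspector` and `wrapOne` as hypothetical stand-ins:

    trait Inspector
    // Stand-in for the real per-value conversion against an ObjectInspector.
    def wrapOne(v: Any, oi: Inspector): AnyRef = v.asInstanceOf[AnyRef]

    // Index-based while-loop over a reused cache array: the per-row hot path
    // allocates nothing beyond what wrapOne itself produces.
    def wrapAll(get: Int => Any, inspectors: Seq[Inspector],
        cache: Array[AnyRef]): Array[AnyRef] = {
      var i = 0
      while (i < inspectors.length) {
        cache(i) = wrapOne(get(i), inspectors(i))
        i += 1
      }
      cache
    }

Both overloads could delegate here as `wrapAll(row.apply, inspectors, cache)`, since `Row` and `Seq` each expose an index-based `apply`.
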
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index d898b876c39f8..76d2140372197 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -360,7 +360,7 @@ private[hive] case class HiveUdafFunction(
protected lazy val cached = new Array[AnyRef](exprs.length)
def update(input: Row): Unit = {
- val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
+ val inputs = inputProjection(input)
function.iterate(buffer, wrap(inputs, inspectors, cached))
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index cc8bb3e172c6e..aae175e426ade 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -209,7 +209,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = {
val dynamicPartPath = dynamicPartColNames
- .zip(row.takeRight(dynamicPartColNames.length))
+ .zip(row.toSeq.takeRight(dynamicPartColNames.length))
.map { case (col, rawVal) =>
val string = if (rawVal == null) null else String.valueOf(rawVal)
s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}"
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f89c49d292c6c..f320d732fb77a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util._
* So, we duplicate this code here.
*/
class QueryTest extends PlanTest {
+
/**
* Runs the plan and makes sure the answer contains all of the keywords, or the
* none of keywords are listed in the answer
@@ -56,17 +57,20 @@ class QueryTest extends PlanTest {
* @param rdd the [[SchemaRDD]] to be executed
- * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = {
- val convertedAnswer = expectedAnswer match {
- case s: Seq[_] if s.isEmpty => s
- case s: Seq[_] if s.head.isInstanceOf[Product] &&
- !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq)
- case s: Seq[_] => s
- case singleItem => Seq(Seq(singleItem))
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
+ val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+ def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+ // Converts data to types that we can do equality comparison using Scala collections.
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+ // Java's java.math.BigDecimal.compareTo).
+ val converted: Seq[Row] = answer.map { s =>
+ Row.fromSeq(s.toSeq.map {
+ case d: java.math.BigDecimal => BigDecimal(d)
+ case o => o
+ })
+ }
+ if (!isSorted) converted.sortBy(_.toString) else converted
}
-
- val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty
- def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer
val sparkAnswer = try rdd.collect().toSeq catch {
case e: Exception =>
fail(
@@ -74,11 +78,12 @@ class QueryTest extends PlanTest {
|Exception thrown while executing query:
|${rdd.queryExecution}
|== Exception ==
- |${stackTraceToString(e)}
+ |$e
+ |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin)
}
- if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+ if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
fail(s"""
|Results do not match for query:
|${rdd.logicalPlan}
@@ -88,11 +93,22 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${convertedAnswer.size} ==" +:
- prepareAnswer(convertedAnswer).map(_.toString),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
""".stripMargin)
}
}
+
+ protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+ test(sqlString) {
+ checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ }
+ }
+
}
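
The rewritten `checkAnswer` normalizes both the expected and actual rows before comparing: unless the logical plan contains a `Sort`, rows are ordered by `toString` so result order is ignored, and `java.math.BigDecimal` values are converted to Scala's `BigDecimal` so equality behaves like `compareTo`. A standalone sketch of that normalization, with a hypothetical `MiniRow` in place of Spark's `Row`:

    case class MiniRow(values: Any*)

    def prepareAnswer(answer: Seq[MiniRow], isSorted: Boolean): Seq[MiniRow] = {
      val converted = answer.map { r =>
        MiniRow(r.values.map {
          case d: java.math.BigDecimal => BigDecimal(d) // equality via compareTo
          case o => o
        }: _*)
      }
      // Without an explicit ORDER BY, row order is an implementation detail,
      // so sort both sides deterministically before comparing.
      if (!isSorted) converted.sortBy(_.toString) else converted
    }

    // Unordered answers compare equal once normalized:
    assert(prepareAnswer(Seq(MiniRow(2), MiniRow(1)), isSorted = false) ==
      prepareAnswer(Seq(MiniRow(1), MiniRow(2)), isSorted = false))
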
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 4864607252034..2d3ff680125ad 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -129,6 +129,12 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors {
}
}
+ def checkValues(row1: Seq[Any], row2: Row): Unit = {
+ row1.zip(row2.toSeq).foreach {
+ case (r1, r2) => checkValue(r1, r2)
+ }
+ }
+
def checkValue(v1: Any, v2: Any): Unit = {
(v1, v2) match {
case (r1: Decimal, r2: Decimal) =>
@@ -198,7 +204,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors {
case (t, idx) => StructField(s"c_$idx", t)
})
- checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row])
+ checkValues(row, unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row])
checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 7cfb875e05db3..0e6636d38ed3c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -43,7 +43,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq
+ testData.collect().toSeq.map(Row.fromTuple)
)
// Add more data.
@@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the table has been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq ++ testData.collect().toSeq
+ testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq
)
// Now overwrite.
@@ -61,7 +61,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the registered table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq
+ testData.collect().toSeq.map(Row.fromTuple)
)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 53d8aa7739bc2..7408c7ffd69e8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -155,7 +155,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a", "b") :: Nil)
+ Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
@@ -164,14 +164,14 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
// will show.
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a1", "b1") :: Nil)
+ Row("a1", "b1"))
refreshTable("jsonTable")
// Check that the refresh worked
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a1", "b1", "c1") :: Nil)
+ Row("a1", "b1", "c1"))
FileUtils.deleteDirectory(tempDir)
}
@@ -191,7 +191,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a", "b") :: Nil)
+ Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
@@ -210,7 +210,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
// New table should reflect new schema.
checkAnswer(
sql("SELECT * FROM jsonTable"),
- ("a", "b", "c") :: Nil)
+ Row("a", "b", "c"))
FileUtils.deleteDirectory(tempDir)
}
@@ -253,6 +253,6 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|)
""".stripMargin)
- sql("DROP TABLE jsonTable").collect.foreach(println)
+ sql("DROP TABLE jsonTable").collect().foreach(println)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 0b4e76c9d3d2f..6f07fd5a879c0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll
import scala.reflect.ClassTag
-import org.apache.spark.sql.{SQLConf, QueryTest}
+import org.apache.spark.sql.{Row, SQLConf, QueryTest}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
@@ -141,7 +141,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
before: () => Unit,
after: () => Unit,
query: String,
- expectedAnswer: Seq[Any],
+ expectedAnswer: Seq[Row],
ct: ClassTag[_]) = {
before()
@@ -183,7 +183,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
/** Tests for MetastoreRelation */
val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key"""
- val metastoreAnswer = Seq.fill(4)((238, "val_238", 238, "val_238"))
+ val metastoreAnswer = Seq.fill(4)(Row(238, "val_238", 238, "val_238"))
mkTest(
() => (),
() => (),
@@ -197,7 +197,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
val leftSemiJoinQuery =
"""SELECT * FROM src a
|left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
- val answer = (86, "val_86") :: Nil
+ val answer = Row(86, "val_86")
var rdd = sql(leftSemiJoinQuery)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index c14f0d24e0dc3..df72be7746ac6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -226,7 +226,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// Jdk version leads to different query output for double, so not use createQueryTest here
test("division") {
val res = sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head
- Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res).foreach( x =>
+ Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res.toSeq).foreach( x =>
assert(x._1 == x._2.asInstanceOf[Double]))
}
@@ -235,7 +235,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
test("Query expressed in SQL") {
setConf("spark.sql.dialect", "sql")
- assert(sql("SELECT 1").collect() === Array(Seq(1)))
+ assert(sql("SELECT 1").collect() === Array(Row(1)))
setConf("spark.sql.dialect", "hiveql")
}
@@ -467,7 +467,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TestData(2, "str2") :: Nil)
testData.registerTempTable("REGisteredTABle")
- assertResult(Array(Array(2, "str2"))) {
+ assertResult(Array(Row(2, "str2"))) {
sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " +
"WHERE TableAliaS.a > 1").collect()
}
@@ -553,12 +553,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// Describe a table
assertResult(
Array(
- Array("key", "int", null),
- Array("value", "string", null),
- Array("dt", "string", null),
- Array("# Partition Information", "", ""),
- Array("# col_name", "data_type", "comment"),
- Array("dt", "string", null))
+ Row("key", "int", null),
+ Row("value", "string", null),
+ Row("dt", "string", null),
+ Row("# Partition Information", "", ""),
+ Row("# col_name", "data_type", "comment"),
+ Row("dt", "string", null))
) {
sql("DESCRIBE test_describe_commands1")
.select('col_name, 'data_type, 'comment)
@@ -568,12 +568,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// Describe a table with a fully qualified table name
assertResult(
Array(
- Array("key", "int", null),
- Array("value", "string", null),
- Array("dt", "string", null),
- Array("# Partition Information", "", ""),
- Array("# col_name", "data_type", "comment"),
- Array("dt", "string", null))
+ Row("key", "int", null),
+ Row("value", "string", null),
+ Row("dt", "string", null),
+ Row("# Partition Information", "", ""),
+ Row("# col_name", "data_type", "comment"),
+ Row("dt", "string", null))
) {
sql("DESCRIBE default.test_describe_commands1")
.select('col_name, 'data_type, 'comment)
@@ -623,8 +623,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
assertResult(
Array(
- Array("a", "IntegerType", null),
- Array("b", "StringType", null))
+ Row("a", "IntegerType", null),
+ Row("b", "StringType", null))
) {
sql("DESCRIBE test_describe_commands2")
.select('col_name, 'data_type, 'comment)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index 5dafcd6c0a76a..f2374a215291b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -64,7 +64,7 @@ class HiveUdfSuite extends QueryTest {
test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
checkAnswer(
sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"),
- 8
+ Row(8)
)
}
@@ -115,7 +115,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'")
checkAnswer(
sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(),
- Seq(Seq("1"), Seq("2")))
+ Seq(Row("1"), Row("2")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString")
TestHive.reset()
@@ -131,7 +131,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'")
checkAnswer(
sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(),
- Seq(Seq(0), Seq(2), Seq(13)))
+ Seq(Row(0), Row(2), Row(13)))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt")
TestHive.reset()
@@ -146,7 +146,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'")
checkAnswer(
sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(),
- Seq(Seq("a,b,c"), Seq("d,e")))
+ Seq(Row("a,b,c"), Row("d,e")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString")
TestHive.reset()
@@ -160,7 +160,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'")
checkAnswer(
sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(),
- Seq(Seq("hello world"), Seq("hello goodbye")))
+ Seq(Row("hello world"), Row("hello goodbye")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf")
TestHive.reset()
@@ -177,7 +177,7 @@ class HiveUdfSuite extends QueryTest {
sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'")
checkAnswer(
sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(),
- Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13")))
+ Seq(Row("0, 0"), Row("2, 2"), Row("13, 13")))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList")
TestHive.reset()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index d41eb9e870bf0..7f9f1ac7cd80d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -41,7 +41,7 @@ class SQLQuerySuite extends QueryTest {
}
test("CTAS with serde") {
- sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect
+ sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect()
sql(
"""CREATE TABLE ctas2
| ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
@@ -51,23 +51,23 @@ class SQLQuerySuite extends QueryTest {
| AS
| SELECT key, value
| FROM src
- | ORDER BY key, value""".stripMargin).collect
+ | ORDER BY key, value""".stripMargin).collect()
sql(
"""CREATE TABLE ctas3
| ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012'
| STORED AS textfile AS
| SELECT key, value
| FROM src
- | ORDER BY key, value""".stripMargin).collect
+ | ORDER BY key, value""".stripMargin).collect()
// the table schema may like (key: integer, value: string)
sql(
"""CREATE TABLE IF NOT EXISTS ctas4 AS
- | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect
+ | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect()
// do nothing cause the table ctas4 already existed.
sql(
"""CREATE TABLE IF NOT EXISTS ctas4 AS
- | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect
+ | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect()
checkAnswer(
sql("SELECT k, value FROM ctas1 ORDER BY k, value"),
@@ -89,7 +89,7 @@ class SQLQuerySuite extends QueryTest {
intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] {
sql(
"""CREATE TABLE ctas4 AS
- | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect
+ | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect()
}
checkAnswer(
sql("SELECT key, value FROM ctas4 ORDER BY key, value"),
@@ -104,6 +104,24 @@ class SQLQuerySuite extends QueryTest {
)
}
+ test("command substitution") {
+ sql("set tbl=src")
+ checkAnswer(
+ sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"),
+ sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq)
+
+ sql("set hive.variable.substitute=false") // disable the substitution
+ sql("set tbl2=src")
+ intercept[Exception] {
+ sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect()
+ }
+
+ sql("set hive.variable.substitute=true") // enable the substitution
+ checkAnswer(
+ sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"),
+ sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq)
+ }
+
test("ordering not in select") {
checkAnswer(
sql("SELECT key FROM src ORDER BY value"),
@@ -126,7 +144,7 @@ class SQLQuerySuite extends QueryTest {
sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested")
checkAnswer(
sql("SELECT f1.f2.f3 FROM nested"),
- 1)
+ Row(1))
checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"),
Seq.empty[Row])
checkAnswer(
@@ -233,7 +251,7 @@ class SQLQuerySuite extends QueryTest {
| (s struct<innerStruct: struct<s1:string>,
| innerArray:array<int>,
| innerMap: map<string, array<int>>>)
- """.stripMargin).collect
+ """.stripMargin).collect()
sql(
"""
@@ -243,7 +261,7 @@ class SQLQuerySuite extends QueryTest {
checkAnswer(
sql("SELECT * FROM nullValuesInInnerComplexTypes"),
- Seq(Seq(Seq(null, null, null)))
+ Row(Row(null, null, null))
)
sql("DROP TABLE nullValuesInInnerComplexTypes")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
index 4bc14bad0ad5f..581f666399492 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
@@ -39,7 +39,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
test("SELECT on Parquet table") {
val data = (1 to 4).map(i => (i, s"val_$i"))
withParquetTable(data, "t") {
- checkAnswer(sql("SELECT * FROM t"), data)
+ checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple))
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
index 8bbb7f2fdbf48..79fd99d9f89ff 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
@@ -177,81 +177,81 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
test(s"ordering of the partitioning columns $table") {
checkAnswer(
sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
- Seq.fill(10)((1, "part-1"))
+ Seq.fill(10)(Row(1, "part-1"))
)
checkAnswer(
sql(s"SELECT stringField, p FROM $table WHERE p = 1"),
- Seq.fill(10)(("part-1", 1))
+ Seq.fill(10)(Row("part-1", 1))
)
}
test(s"project the partitioning column $table") {
checkAnswer(
sql(s"SELECT p, count(*) FROM $table group by p"),
- (1, 10) ::
- (2, 10) ::
- (3, 10) ::
- (4, 10) ::
- (5, 10) ::
- (6, 10) ::
- (7, 10) ::
- (8, 10) ::
- (9, 10) ::
- (10, 10) :: Nil
+ Row(1, 10) ::
+ Row(2, 10) ::
+ Row(3, 10) ::
+ Row(4, 10) ::
+ Row(5, 10) ::
+ Row(6, 10) ::
+ Row(7, 10) ::
+ Row(8, 10) ::
+ Row(9, 10) ::
+ Row(10, 10) :: Nil
)
}
test(s"project partitioning and non-partitioning columns $table") {
checkAnswer(
sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"),
- ("part-1", 1, 10) ::
- ("part-2", 2, 10) ::
- ("part-3", 3, 10) ::
- ("part-4", 4, 10) ::
- ("part-5", 5, 10) ::
- ("part-6", 6, 10) ::
- ("part-7", 7, 10) ::
- ("part-8", 8, 10) ::
- ("part-9", 9, 10) ::
- ("part-10", 10, 10) :: Nil
+ Row("part-1", 1, 10) ::
+ Row("part-2", 2, 10) ::
+ Row("part-3", 3, 10) ::
+ Row("part-4", 4, 10) ::
+ Row("part-5", 5, 10) ::
+ Row("part-6", 6, 10) ::
+ Row("part-7", 7, 10) ::
+ Row("part-8", 8, 10) ::
+ Row("part-9", 9, 10) ::
+ Row("part-10", 10, 10) :: Nil
)
}
test(s"simple count $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table"),
- 100)
+ Row(100))
}
test(s"pruned count $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"),
- 10)
+ Row(10))
}
test(s"non-existant partition $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"),
- 0)
+ Row(0))
}
test(s"multi-partition pruned count $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"),
- 30)
+ Row(30))
}
test(s"non-partition predicates $table") {
checkAnswer(
sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"),
- 30)
+ Row(30))
}
test(s"sum $table") {
checkAnswer(
sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"),
- 1 + 2 + 3)
+ Row(1 + 2 + 3))
}
test(s"hive udfs $table") {
@@ -266,6 +266,6 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
test("non-part select(*)") {
checkAnswer(
sql("SELECT COUNT(*) FROM normal_parquet"),
- 10)
+ Row(10))
}
}
diff --git a/streaming/pom.xml b/streaming/pom.xml
index d3c6d0347a622..22b0d714b57f6 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -96,5 +96,13 @@
+ <resources>
+ <resource>
+ <directory>../python</directory>
+ <includes>
+ <include>pyspark/streaming/*.py</include>
+ </includes>
+ </resource>
+ </resources>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index e59c24adb84af..0e285d6088ec1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
}
}
+ /**
+ * Get the maximum remember duration across all the input streams. This is a conservative but
+ * safe remember duration which can be used to perform cleanup operations.
+ */
+ def getMaxInputStreamRememberDuration(): Duration = {
+ inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds }
+ }
+
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
logDebug("DStreamGraph.writeObject used")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index d8695b8e05962..9a2254bcdc1f7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -17,14 +17,15 @@
package org.apache.spark.streaming.api.java
+import java.lang.{Boolean => JBoolean}
+import java.io.{Closeable, InputStream}
+import java.util.{List => JList, Map => JMap}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
-import java.io.{Closeable, InputStream}
-import java.util.{List => JList, Map => JMap}
-
import akka.actor.{Props, SupervisorStrategy}
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.spark.{SparkConf, SparkContext}
@@ -250,21 +251,53 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
* Files must be written to the monitored directory by "moving" them from another
* location within the same file system. File names starting with . are ignored.
* @param directory HDFS directory to monitor for new file
+ * @param kClass class of key for reading HDFS file
+ * @param vClass class of value for reading HDFS file
+ * @param fClass class of input format for reading HDFS file
* @tparam K Key type for reading HDFS file
* @tparam V Value type for reading HDFS file
* @tparam F Input format for reading HDFS file
*/
def fileStream[K, V, F <: NewInputFormat[K, V]](
- directory: String): JavaPairInputDStream[K, V] = {
- implicit val cmk: ClassTag[K] =
- implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
- implicit val cmv: ClassTag[V] =
- implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
- implicit val cmf: ClassTag[F] =
- implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[F]]
+ directory: String,
+ kClass: Class[K],
+ vClass: Class[V],
+ fClass: Class[F]): JavaPairInputDStream[K, V] = {
+ implicit val cmk: ClassTag[K] = ClassTag(kClass)
+ implicit val cmv: ClassTag[V] = ClassTag(vClass)
+ implicit val cmf: ClassTag[F] = ClassTag(fClass)
ssc.fileStream[K, V, F](directory)
}
+ /**
+ * Create an input stream that monitors a Hadoop-compatible filesystem
+ * for new files and reads them using the given key-value types and input format.
+ * Files must be written to the monitored directory by "moving" them from another
+ * location within the same file system. File names starting with . are ignored.
+ * @param directory HDFS directory to monitor for new files
+ * @param kClass class of key for reading HDFS file
+ * @param vClass class of value for reading HDFS file
+ * @param fClass class of input format for reading HDFS file
+ * @param filter Function to filter paths to process
+ * @param newFilesOnly Whether to process only new files and ignore existing files in the directory
+ * @tparam K Key type for reading HDFS file
+ * @tparam V Value type for reading HDFS file
+ * @tparam F Input format for reading HDFS file
+ */
+ def fileStream[K, V, F <: NewInputFormat[K, V]](
+ directory: String,
+ kClass: Class[K],
+ vClass: Class[V],
+ fClass: Class[F],
+ filter: JFunction[Path, JBoolean],
+ newFilesOnly: Boolean): JavaPairInputDStream[K, V] = {
+ implicit val cmk: ClassTag[K] = ClassTag(kClass)
+ implicit val cmv: ClassTag[V] = ClassTag(vClass)
+ implicit val cmf: ClassTag[F] = ClassTag(fClass)
+ def fn = (x: Path) => filter.call(x).booleanValue()
+ ssc.fileStream[K, V, F](directory, fn, newFilesOnly)
+ }
+
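
The pattern behind the new overloads: Java callers cannot supply Scala ClassTags, but they can supply Class objects, and ClassTag(clazz) recovers a precise tag from one instead of erasing everything to AnyRef. A small illustrative sketch (the helper name is hypothetical):

```scala
import scala.reflect.ClassTag

// Hypothetical helper: turn a Class[T] supplied from Java into the
// ClassTag[T] that a Scala generic API requires.
def tagOf[T](clazz: Class[T]): ClassTag[T] = ClassTag(clazz)

assert(tagOf(classOf[String]).runtimeClass == classOf[String])
```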
/**
* Create an input stream with any arbitrary user implemented actor receiver.
* @param props Props object defining creation of the actor
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index afd3c4bc4c4fe..8be04314c4285 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
}
Some(blockRDD)
}
-
- /**
- * Clear metadata that are older than `rememberDuration` of this DStream.
- * This is an internal method that should not be called directly. This
- * implementation overrides the default implementation to clear received
- * block information.
- */
- private[streaming] override def clearMetadata(time: Time) {
- super.clearMetadata(time)
- ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration)
- }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
index ab9fa192191aa..7bf3c33319491 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
@@ -17,7 +17,10 @@
package org.apache.spark.streaming.receiver
-/** Messages sent to the NetworkReceiver. */
+import org.apache.spark.streaming.Time
+
+/** Messages sent to the Receiver. */
private[streaming] sealed trait ReceiverMessage extends Serializable
private[streaming] object StopReceiver extends ReceiverMessage
+private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage
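
Since ReceiverMessage is sealed, the compiler can verify that receiver-side matches cover the new CleanupOldBlocks case. A minimal sketch of the dispatch, with hypothetical names standing in for the real hierarchy and handler:

```scala
// Hypothetical mirror of the message hierarchy, for illustration only;
// the threshold time is reduced to a Long timestamp.
sealed trait Msg extends Serializable
case object Stop extends Msg
case class Cleanup(threshTimeMs: Long) extends Msg

def handle(message: Msg): String = message match {
  case Stop                => "stopping receiver"
  case Cleanup(threshTime) => s"deleting blocks older than $threshTime ms"
}

assert(handle(Cleanup(1000L)) == "deleting blocks older than 1000 ms")
```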
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index d7229c2b96d0b..716cf2c7f32fc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.Time
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl(
case StopReceiver =>
logInfo("Received stop signal")
stop("Stopped by driver", None)
+ case CleanupOldBlocks(threshTime) =>
+ logDebug("Received delete old batch signal")
+ cleanupOldBlocks(threshTime)
}
def ref = self
@@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl(
/** Generate new block ID */
private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement)
+
+ private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = {
+ logDebug(s"Cleaning up blocks older then $cleanupThreshTime")
+ receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds)
+ }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 39b66e1130768..d86f852aba97e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -238,13 +238,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Clear DStream metadata for the given `time`. */
private def clearMetadata(time: Time) {
ssc.graph.clearMetadata(time)
- jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration)
// If checkpointing is enabled, then checkpoint,
// else mark batch to be fully processed
if (shouldCheckpoint) {
eventActor ! DoCheckpoint(time)
} else {
+ // If checkpointing is not enabled, then delete the metadata about received
+ // blocks here (the block data is not saved in any case). Otherwise, wait for
+ // the checkpointing of this batch to complete.
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
+ jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
markBatchFullyProcessed(time)
}
}
@@ -252,6 +256,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Clear DStream checkpoint data for the given `time`. */
private def clearCheckpointData(time: Time) {
ssc.graph.clearCheckpointData(time)
+
+ // All the checkpoint information about which batches have been processed, etc.,
+ // has been saved to the checkpoint, so it's safe to delete block metadata and data WAL files
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
+ jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
markBatchFullyProcessed(time)
}
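
The cleanup threshold above is simply the batch time shifted back by the largest remember duration, so input streams with long windows keep their data (and WAL segments) longer. A worked sketch of the arithmetic with Spark's Time and Duration types:

```scala
import org.apache.spark.streaming.{Milliseconds, Time}

// Suppose the current batch time is 10000 ms and the longest remember
// duration across all input streams is 3000 ms (e.g. from a window).
val batchTime = Time(10000)
val maxRememberDuration = Milliseconds(3000)

// Everything strictly older than this threshold is safe to delete.
val cleanupThreshTime = batchTime - maxRememberDuration
assert(cleanupThreshTime == Time(7000))
```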
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index c3d9d7b6813d3..ef23b5c79f2e1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -150,7 +150,6 @@ private[streaming] class ReceivedBlockTracker(
writeToLog(BatchCleanupEvent(timesToCleanup))
timeToAllocatedBlocks --= timesToCleanup
logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion))
- log
}
/** Stop the block tracker. */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 8dbb42a86e3bd..4f998869731ed 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -24,9 +24,8 @@ import scala.language.existentials
import akka.actor._
import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException}
-import org.apache.spark.SparkContext._
import org.apache.spark.streaming.{StreamingContext, Time}
-import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver}
+import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver}
/**
* Messages used by the NetworkReceiver and the ReceiverTracker to communicate
@@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
}
- /** Clean up metadata older than the given threshold time */
- def cleanupOldMetadata(cleanupThreshTime: Time) {
+ /**
+ * Clean up the data and metadata of blocks and batches that are strictly
+ * older than the threshold time.
+ */
+ def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) {
+ // Clean up old block and batch metadata
receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false)
+
+ // Signal the receivers to delete old block data
+ if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) {
+ logInfo(s"Cleanup old received batch data: $cleanupThreshTime")
+ receiverInfo.values.flatMap { info => Option(info.actor) }
+ .foreach { _ ! CleanupOldBlocks(cleanupThreshTime) }
+ }
}
/** Register a receiver */
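
The gating logic above only signals receivers to drop block data when the receiver-side WAL is enabled; without a WAL the data lives only in the block manager, and there is nothing durable to delete. A hedged, actor-free sketch of that shape (plain callbacks stand in for the actor sends):

```scala
// Hypothetical stand-ins: `walEnabled` mirrors the
// spark.streaming.receiver.writeAheadLog.enable flag, and each entry in
// `receivers` mirrors an `actor ! CleanupOldBlocks(threshTime)` send.
def cleanupOldBlocksAndBatches(
    cleanupThreshTimeMs: Long,
    walEnabled: Boolean,
    receivers: Seq[Long => Unit]): Unit = {
  // Driver-side batch metadata is always cleaned up (elided here);
  // receivers are only signalled when WAL data exists to delete.
  if (walEnabled) {
    receivers.foreach(send => send(cleanupThreshTimeMs))
  }
}
```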
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
index 27a28bab83ed5..858ba3c9eb4e5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
@@ -63,7 +63,7 @@ private[streaming] object HdfsUtils {
}
def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = {
- // For local file systems, return the raw loca file system, such calls to flush()
+ // For local file systems, return the raw local file system, because a call to flush()
// actually flushes the stream.
val fs = path.getFileSystem(conf)
fs match {
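
The match that begins here (its arms fall outside this hunk) unwraps the local filesystem to its raw variant. A sketch of the idea, assuming Hadoop's standard FileSystem hierarchy:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, LocalFileSystem, Path}

// Sketch: LocalFileSystem is a ChecksumFileSystem wrapper around
// RawLocalFileSystem; flush() on the wrapper is not durable, so WAL
// writers want the raw filesystem underneath.
def rawIfLocal(path: Path, conf: Configuration): FileSystem =
  path.getFileSystem(conf) match {
    case localFs: LocalFileSystem => localFs.getRawFileSystem
    case other => other
  }
```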
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 12cc0de7509d6..d92e7fe899a09 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -17,13 +17,20 @@
package org.apache.spark.streaming;
+import java.io.*;
+import java.lang.Iterable;
+import java.nio.charset.Charset;
+import java.util.*;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import scala.Tuple2;
import org.junit.Assert;
+import static org.junit.Assert.*;
import org.junit.Test;
-import java.io.*;
-import java.util.*;
-import java.lang.Iterable;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
@@ -1743,13 +1750,66 @@ public Iterable<String> call(InputStream in) throws IOException {
StorageLevel.MEMORY_ONLY());
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testTextFileStream() throws IOException {
+ File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"));
+ List<List<String>> expected = fileTestPrepare(testDir);
+
+ JavaDStream<String> input = ssc.textFileStream(testDir.toString());
+ JavaTestUtils.attachTestOutputStream(input);
+ List<List<String>> result = JavaTestUtils.runStreams(ssc, 1, 1);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @SuppressWarnings("unchecked")
@Test
- public void testTextFileStream() {
- JavaDStream<String> test = ssc.textFileStream("/tmp/foo");
+ public void testFileStream() throws IOException {
+ File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"));
+ List<List<String>> expected = fileTestPrepare(testDir);
+
+ JavaPairInputDStream<LongWritable, Text> inputStream = ssc.fileStream(
+ testDir.toString(),
+ LongWritable.class,
+ Text.class,
+ TextInputFormat.class,
+ new Function<Path, Boolean>() {
+ @Override
+ public Boolean call(Path v1) throws Exception {
+ return Boolean.TRUE;
+ }
+ },
+ true);
+
+ JavaDStream<String> test = inputStream.map(
+ new Function<Tuple2<LongWritable, Text>, String>() {
+ @Override
+ public String call(Tuple2<LongWritable, Text> v1) throws Exception {
+ return v1._2().toString();
+ }
+ });
+
+ JavaTestUtils.attachTestOutputStream(test);
+ List<List<String>> result = JavaTestUtils.runStreams(ssc, 1, 1);
+
+ assertOrderInvariantEquals(expected, result);
}
@Test
public void testRawSocketStream() {
JavaReceiverInputDStream<String> test = ssc.rawSocketStream("localhost", 12345);
}
+
+ private List<List<String>> fileTestPrepare(File testDir) throws IOException {
+ File existingFile = new File(testDir, "0");
+ Files.write("0\n", existingFile, Charset.forName("UTF-8"));
+ assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() == 1000);
+
+ List<List<String>> expected = Arrays.asList(
+ Arrays.asList("0")
+ );
+
+ return expected;
+ }
}
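
For comparison with the Java test above, the same stream can be built from Scala by passing the types as type parameters and the filter as a plain function. A hedged usage sketch (directory and context setup are placeholders):

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setMaster("local[2]").setAppName("fileStreamDemo")
val ssc = new StreamingContext(conf, Seconds(1))

// Scala callers do not need the Class[_] overloads added above.
val lines = ssc
  .fileStream[LongWritable, Text, TextInputFormat](
    "/tmp/dir", // placeholder directory
    (path: Path) => !path.getName.startsWith("."),
    newFilesOnly = true)
  .map(_._2.toString)
```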
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index e26c0c6859e57..e8c34a9ee40b9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -17,21 +17,26 @@
package org.apache.spark.streaming
+import java.io.File
import java.nio.ByteBuffer
import java.util.concurrent.Semaphore
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkConf
-import org.apache.spark.storage.{StorageLevel, StreamBlockId}
-import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor}
-import org.scalatest.FunSuite
+import com.google.common.io.Files
import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.receiver._
+import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._
+
/** Testsuite for testing the network receiver behavior */
-class ReceiverSuite extends FunSuite with Timeouts {
+class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
test("receiver life cycle") {
@@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts {
val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3
val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1
val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",")
- println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes)
assert(
// the first and last block may be incomplete, so we slice them out
recordedBlocks.drop(1).dropRight(1).forall { block =>
@@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts {
)
}
-
/**
- * An implementation of NetworkReceiver that is used for testing a receiver's life cycle.
+ * Test whether write ahead logs are generated by receivers,
+ * and automatically cleaned up. The cleanup must be aware of the
+ * remember duration of the input streams. E.g., input streams on which window()
+ * has been applied must remember the data for longer, and hence corresponding
+ * WALs should be cleaned later.
*/
- class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
- @volatile var otherThread: Thread = null
- @volatile var receiving = false
- @volatile var onStartCalled = false
- @volatile var onStopCalled = false
-
- def onStart() {
- otherThread = new Thread() {
- override def run() {
- receiving = true
- while(!isStopped()) {
- Thread.sleep(10)
- }
+ test("write ahead log - generating and cleaning") {
+ val sparkConf = new SparkConf()
+ .setMaster("local[4]") // must be at least 3 as we are going to start 2 receivers
+ .setAppName(framework)
+ .set("spark.ui.enabled", "true")
+ .set("spark.streaming.receiver.writeAheadLog.enable", "true")
+ .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1")
+ val batchDuration = Milliseconds(500)
+ val tempDirectory = Files.createTempDir()
+ val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0))
+ val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1))
+ val allLogFiles1 = new mutable.HashSet[String]()
+ val allLogFiles2 = new mutable.HashSet[String]()
+ logInfo("Temp checkpoint directory = " + tempDirectory)
+
+ def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = {
+ (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2))
+ }
+
+ def getCurrentLogFiles(logDirectory: File): Seq[String] = {
+ try {
+ if (logDirectory.exists()) {
+ logDirectory.listFiles().filter { _.getName.startsWith("log") }.map { _.toString }
+ } else {
+ Seq.empty
}
+ } catch {
+ case e: Exception =>
+ Seq.empty
}
- onStartCalled = true
- otherThread.start()
-
}
- def onStop() {
- onStopCalled = true
- otherThread.join()
+ def printLogFiles(message: String, files: Seq[String]) {
+ logInfo(s"$message (${files.size} files):\n" + files.mkString("\n"))
}
- def reset() {
- receiving = false
- onStartCalled = false
- onStopCalled = false
+ withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc =>
+ tempDirectory.deleteOnExit()
+ val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true))
+ val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true))
+ val receiverStream1 = ssc.receiverStream(receiver1)
+ val receiverStream2 = ssc.receiverStream(receiver2)
+ receiverStream1.register()
+ receiverStream2.window(batchDuration * 6).register() // 3 second window
+ ssc.checkpoint(tempDirectory.getAbsolutePath())
+ ssc.start()
+
+ // Run until sufficient WAL files have been generated and
+ // the first WAL file has been deleted
+ eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) {
+ val (logFiles1, logFiles2) = getBothCurrentLogFiles()
+ allLogFiles1 ++= logFiles1
+ allLogFiles2 ++= logFiles2
+ if (allLogFiles1.size > 0) {
+ assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head))
+ }
+ if (allLogFiles2.size > 0) {
+ assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head))
+ }
+ assert(allLogFiles1.size >= 7)
+ assert(allLogFiles2.size >= 7)
+ }
+ ssc.stop(stopSparkContext = true, stopGracefully = true)
+
+ val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted
+ val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted
+ val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles()
+
+ printLogFiles("Receiver 0: all", sortedAllLogFiles1)
+ printLogFiles("Receiver 0: left", leftLogFiles1)
+ printLogFiles("Receiver 1: all", sortedAllLogFiles2)
+ printLogFiles("Receiver 1: left", leftLogFiles2)
+
+ // Verify that necessary latest log files are not deleted
+ // receiverStream1 needs to retain just the last batch = 1 log file
+ // receiverStream2 needs to retain 3 seconds (3-seconds window) = 3 log files
+ assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains))
+ assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains))
}
}
@@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts {
}
}
+/**
+ * An implementation of Receiver that is used for testing a receiver's life cycle.
+ */
+class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+ @volatile var otherThread: Thread = null
+ @volatile var receiving = false
+ @volatile var onStartCalled = false
+ @volatile var onStopCalled = false
+
+ def onStart() {
+ otherThread = new Thread() {
+ override def run() {
+ receiving = true
+ var count = 0
+ while(!isStopped()) {
+ if (sendData) {
+ store(count)
+ count += 1
+ }
+ Thread.sleep(10)
+ }
+ }
+ }
+ onStartCalled = true
+ otherThread.start()
+ }
+
+ def onStop() {
+ onStopCalled = true
+ otherThread.join()
+ }
+
+ def reset() {
+ receiving = false
+ onStartCalled = false
+ onStopCalled = false
+ }
+}
+
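
A usage note on the extracted FakeReceiver: the sendData flag is what lets one fake serve both the life-cycle tests (no data) and the WAL test (a steady stream of integers so blocks, and hence WAL segments, keep being produced). Sketch, assuming the suite's StreamingContext is in scope as ssc:

```scala
// Life-cycle tests: observe onStart/onStop without producing data.
val lifecycleReceiver = new FakeReceiver()

// WAL test: store() increasing integers so new blocks keep arriving.
val dataReceiver = new FakeReceiver(sendData = true)
val stream = ssc.receiverStream(dataReceiver)
```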
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 79bead77ba6e4..f96b245512271 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -19,9 +19,9 @@ package org.apache.spark.deploy.yarn
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
-import org.apache.spark.util.{Utils, IntParam, MemoryParam}
+import org.apache.spark.util.{IntParam, MemoryParam, Utils}
// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware !
private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) {
@@ -95,6 +95,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
throw new IllegalArgumentException(
"You must specify at least 1 executor!\n" + getUsageMessage())
}
+ if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) {
+ throw new SparkException("Executor cores must not be less than " +
+ "spark.task.cpus.")
+ }
if (isClusterMode) {
for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) {
if (sparkConf.contains(key)) {
@@ -222,7 +226,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
| --arg ARG Argument to be passed to your application's main class.
| Multiple invocations are possible, each will be passed in order.
| --num-executors NUM Number of executors to start (Default: 2)
- | --executor-cores NUM Number of cores for the executors (Default: 1).
+ | --executor-cores NUM Number of cores per executor (Default: 1).
| --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb)
| --driver-cores NUM Number of cores used by the driver (Default: 1).
| --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G)
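
The new guard mirrors the check added to ExecutorAllocationManager earlier in this patch: an executor with fewer cores than spark.task.cpus could never schedule a single task. A free-standing sketch of the validation (hypothetical helper name):

```scala
import org.apache.spark.{SparkConf, SparkException}

// Hypothetical extraction of the check performed while parsing client args.
def validateExecutorCores(executorCores: Int, sparkConf: SparkConf): Unit = {
  if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) {
    throw new SparkException("Executor cores must not be less than spark.task.cpus.")
  }
}
```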
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index de65ef23ad1ce..d00f29665a58f 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -17,8 +17,8 @@
package org.apache.spark.deploy.yarn
+import java.util.Collections
import java.util.concurrent._
-import java.util.concurrent.atomic.AtomicInteger
import java.util.regex.Pattern
import scala.collection.JavaConversions._
@@ -28,33 +28,26 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api.records._
-import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
-import org.apache.hadoop.yarn.util.Records
+import org.apache.hadoop.yarn.util.RackResolver
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
-import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
-object AllocationType extends Enumeration {
- type AllocationType = Value
- val HOST, RACK, ANY = Value
-}
-
-// TODO:
-// Too many params.
-// Needs to be mt-safe
-// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should
-// make it more proactive and decoupled.
-
-// Note that right now, we assume all node asks as uniform in terms of capabilities and priority
-// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for
-// more info on how we are requesting for containers.
-
/**
- * Acquires resources for executors from a ResourceManager and launches executors in new containers.
+ * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding
+ * what to do with containers when YARN fulfills these requests.
+ *
+ * This class makes use of YARN's AMRMClient APIs. We interact with the AMRMClient in three ways:
+ * * Making our resource needs known, which updates local bookkeeping about containers requested.
+ * * Calling "allocate", which syncs our local container requests with the RM, and returns any
+ * containers that YARN has granted to us. This also functions as a heartbeat.
+ * * Processing the containers granted to us to possibly launch executors inside of them.
+ *
+ * The public methods of this class are thread-safe. All methods that mutate state are
+ * synchronized.
*/
private[yarn] class YarnAllocator(
conf: Configuration,
@@ -62,50 +55,41 @@ private[yarn] class YarnAllocator(
amClient: AMRMClient[ContainerRequest],
appAttemptId: ApplicationAttemptId,
args: ApplicationMasterArguments,
- preferredNodes: collection.Map[String, collection.Set[SplitInfo]],
securityMgr: SecurityManager)
extends Logging {
import YarnAllocator._
- // These three are locked on allocatedHostToContainersMap. Complementary data structures
- // allocatedHostToContainersMap : containers which are running : host, Set
- // allocatedContainerToHostMap: container to host mapping.
- private val allocatedHostToContainersMap =
- new HashMap[String, collection.mutable.Set[ContainerId]]()
+ // Visible for testing.
+ val allocatedHostToContainersMap =
+ new HashMap[String, collection.mutable.Set[ContainerId]]
+ val allocatedContainerToHostMap = new HashMap[ContainerId, String]
- private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
+ // Containers that we no longer care about. We've either already told the RM to release them or
+ // will on the next heartbeat. Containers get removed from this map after the RM tells us they've
+ // completed.
+ private val releasedContainers = Collections.newSetFromMap[ContainerId](
+ new ConcurrentHashMap[ContainerId, java.lang.Boolean])
- // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an
- // allocated node)
- // As with the two data structures above, tightly coupled with them, and to be locked on
- // allocatedHostToContainersMap
- private val allocatedRackCount = new HashMap[String, Int]()
+ @volatile private var numExecutorsRunning = 0
+ // Used to generate a unique ID per executor
+ private var executorIdCounter = 0
+ @volatile private var numExecutorsFailed = 0
- // Containers to be released in next request to RM
- private val releasedContainers = new ConcurrentHashMap[ContainerId, Boolean]
-
- // Number of container requests that have been sent to, but not yet allocated by the
- // ApplicationMaster.
- private val numPendingAllocate = new AtomicInteger()
- private val numExecutorsRunning = new AtomicInteger()
- // Used to generate a unique id per executor
- private val executorIdCounter = new AtomicInteger()
- private val numExecutorsFailed = new AtomicInteger()
-
- private var maxExecutors = args.numExecutors
+ @volatile private var maxExecutors = args.numExecutors
// Keep track of which container is running which executor to remove the executors later
private val executorIdToContainer = new HashMap[String, Container]
+ // Executor memory in MB.
protected val executorMemory = args.executorMemory
- protected val executorCores = args.executorCores
- protected val (preferredHostToCount, preferredRackToCount) =
- generateNodeToWeight(conf, preferredNodes)
-
- // Additional memory overhead - in mb.
+ // Additional memory overhead.
protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead",
math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN))
+ // Number of cores per executor.
+ protected val executorCores = args.executorCores
+ // Resource capability requested for each executor
+ private val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores)
private val launcherPool = new ThreadPoolExecutor(
// max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue
@@ -115,26 +99,34 @@ private[yarn] class YarnAllocator(
new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build())
launcherPool.allowCoreThreadTimeOut(true)
- def getNumExecutorsRunning: Int = numExecutorsRunning.intValue
+ private val driverUrl = "akka.tcp://sparkDriver@%s:%s/user/%s".format(
+ sparkConf.get("spark.driver.host"),
+ sparkConf.get("spark.driver.port"),
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ // For testing
+ private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true)
- def getNumExecutorsFailed: Int = numExecutorsFailed.intValue
+ def getNumExecutorsRunning: Int = numExecutorsRunning
+
+ def getNumExecutorsFailed: Int = numExecutorsFailed
+
+ /**
+ * Number of container requests that have not yet been fulfilled.
+ */
+ def getNumPendingAllocate: Int = getNumPendingAtLocation(ANY_HOST)
+
+ /**
+ * Number of container requests at the given location that have not yet been fulfilled.
+ */
+ private def getNumPendingAtLocation(location: String): Int =
+ amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum
/**
* Request as many executors from the ResourceManager as needed to reach the desired total.
- * This takes into account executors already running or pending.
*/
def requestTotalExecutors(requestedTotal: Int): Unit = synchronized {
- val currentTotal = numPendingAllocate.get + numExecutorsRunning.get
- if (requestedTotal > currentTotal) {
- maxExecutors += (requestedTotal - currentTotal)
- // We need to call `allocateResources` here to avoid the following race condition:
- // If we request executors twice before `allocateResources` is called, then we will end up
- // double counting the number requested because `numPendingAllocate` is not updated yet.
- allocateResources()
- } else {
- logInfo(s"Not allocating more executors because there are already $currentTotal " +
- s"(application requested $requestedTotal total)")
- }
+ maxExecutors = requestedTotal
}
/**
@@ -144,7 +136,7 @@ private[yarn] class YarnAllocator(
if (executorIdToContainer.contains(executorId)) {
val container = executorIdToContainer.remove(executorId).get
internalReleaseContainer(container)
- numExecutorsRunning.decrementAndGet()
+ numExecutorsRunning -= 1
maxExecutors -= 1
assert(maxExecutors >= 0, "Allocator killed more executors than are allocated!")
} else {
@@ -153,498 +145,234 @@ private[yarn] class YarnAllocator(
}
/**
- * Allocate missing containers based on the number of executors currently pending and running.
+ * Request resources such that, if YARN gives us all we ask for, we'll have a number of containers
+ * equal to maxExecutors.
*
- * This method prioritizes the allocated container responses from the RM based on node and
- * rack locality. Additionally, it releases any extra containers allocated for this application
- * but are not needed. This must be synchronized because variables read in this block are
- * mutated by other methods.
+ * Deal with any containers YARN has granted to us by possibly launching executors in them.
+ *
+ * This must be synchronized because variables read in this method are mutated by other methods.
*/
def allocateResources(): Unit = synchronized {
- val missing = maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get()
+ val numPendingAllocate = getNumPendingAllocate
+ val missing = maxExecutors - numPendingAllocate - numExecutorsRunning
if (missing > 0) {
- val totalExecutorMemory = executorMemory + memoryOverhead
- numPendingAllocate.addAndGet(missing)
- logInfo(s"Will allocate $missing executor containers, each with $totalExecutorMemory MB " +
- s"memory including $memoryOverhead MB overhead")
- } else {
- logDebug("Empty allocation request ...")
+ logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " +
+ s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead")
}
- val allocateResponse = allocateContainers(missing)
+ addResourceRequests(missing)
+ val progressIndicator = 0.1f
+ // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container
+ // requests.
+ val allocateResponse = amClient.allocate(progressIndicator)
+
val allocatedContainers = allocateResponse.getAllocatedContainers()
if (allocatedContainers.size > 0) {
- var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size)
-
- if (numPendingAllocateNow < 0) {
- numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow)
- }
-
- logDebug("""
- Allocated containers: %d
- Current executor count: %d
- Containers released: %s
- Cluster resources: %s
- """.format(
+ logDebug("Allocated containers: %d. Current executor count: %d. Cluster resources: %s."
+ .format(
allocatedContainers.size,
- numExecutorsRunning.get(),
- releasedContainers,
+ numExecutorsRunning,
allocateResponse.getAvailableResources))
- val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
-
- for (container <- allocatedContainers) {
- if (isResourceConstraintSatisfied(container)) {
- // Add the accepted `container` to the host's list of already accepted,
- // allocated containers
- val host = container.getNodeId.getHost
- val containersForHost = hostToContainers.getOrElseUpdate(host,
- new ArrayBuffer[Container]())
- containersForHost += container
- } else {
- // Release container, since it doesn't satisfy resource constraints.
- internalReleaseContainer(container)
- }
- }
-
- // Find the appropriate containers to use.
- // TODO: Cleanup this group-by...
- val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
- val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
- val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
-
- for (candidateHost <- hostToContainers.keySet) {
- val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
- val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
-
- val remainingContainersOpt = hostToContainers.get(candidateHost)
- assert(remainingContainersOpt.isDefined)
- var remainingContainers = remainingContainersOpt.get
-
- if (requiredHostCount >= remainingContainers.size) {
- // Since we have <= required containers, add all remaining containers to
- // `dataLocalContainers`.
- dataLocalContainers.put(candidateHost, remainingContainers)
- // There are no more free containers remaining.
- remainingContainers = null
- } else if (requiredHostCount > 0) {
- // Container list has more containers than we need for data locality.
- // Split the list into two: one based on the data local container count,
- // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining
- // containers.
- val (dataLocal, remaining) = remainingContainers.splitAt(
- remainingContainers.size - requiredHostCount)
- dataLocalContainers.put(candidateHost, dataLocal)
-
- // Invariant: remainingContainers == remaining
-
- // YARN has a nasty habit of allocating a ton of containers on a host - discourage this.
- // Add each container in `remaining` to list of containers to release. If we have an
- // insufficient number of containers, then the next allocation cycle will reallocate
- // (but won't treat it as data local).
- // TODO(harvey): Rephrase this comment some more.
- for (container <- remaining) internalReleaseContainer(container)
- remainingContainers = null
- }
-
- // For rack local containers
- if (remainingContainers != null) {
- val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost)
- if (rack != null) {
- val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
- val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
- rackLocalContainers.getOrElse(rack, List()).size
-
- if (requiredRackCount >= remainingContainers.size) {
- // Add all remaining containers to to `dataLocalContainers`.
- dataLocalContainers.put(rack, remainingContainers)
- remainingContainers = null
- } else if (requiredRackCount > 0) {
- // Container list has more containers that we need for data locality.
- // Split the list into two: one based on the data local container count,
- // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining
- // containers.
- val (rackLocal, remaining) = remainingContainers.splitAt(
- remainingContainers.size - requiredRackCount)
- val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack,
- new ArrayBuffer[Container]())
-
- existingRackLocal ++= rackLocal
-
- remainingContainers = remaining
- }
- }
- }
-
- if (remainingContainers != null) {
- // Not all containers have been consumed - add them to the list of off-rack containers.
- offRackContainers.put(candidateHost, remainingContainers)
- }
- }
-
- // Now that we have split the containers into various groups, go through them in order:
- // first host-local, then rack-local, and finally off-rack.
- // Note that the list we create below tries to ensure that not all containers end up within
- // a host if there is a sufficiently large number of hosts/containers.
- val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size)
- allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers)
- allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers)
- allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers)
-
- // Run each of the allocated containers.
- for (container <- allocatedContainersToProcess) {
- val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet()
- val executorHostname = container.getNodeId.getHost
- val containerId = container.getId
-
- val executorMemoryOverhead = (executorMemory + memoryOverhead)
- assert(container.getResource.getMemory >= executorMemoryOverhead)
-
- if (numExecutorsRunningNow > maxExecutors) {
- logInfo("""Ignoring container %s at host %s, since we already have the required number of
- containers for it.""".format(containerId, executorHostname))
- internalReleaseContainer(container)
- numExecutorsRunning.decrementAndGet()
- } else {
- val executorId = executorIdCounter.incrementAndGet().toString
- val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
- SparkEnv.driverActorSystemName,
- sparkConf.get("spark.driver.host"),
- sparkConf.get("spark.driver.port"),
- CoarseGrainedSchedulerBackend.ACTOR_NAME)
-
- logInfo("Launching container %s for on host %s".format(containerId, executorHostname))
- executorIdToContainer(executorId) = container
-
- // To be safe, remove the container from `releasedContainers`.
- releasedContainers.remove(containerId)
-
- val rack = YarnSparkHadoopUtil.lookupRack(conf, executorHostname)
- allocatedHostToContainersMap.synchronized {
- val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname,
- new HashSet[ContainerId]())
-
- containerSet += containerId
- allocatedContainerToHostMap.put(containerId, executorHostname)
-
- if (rack != null) {
- allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
- }
- }
- logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format(
- driverUrl, executorHostname))
- val executorRunnable = new ExecutorRunnable(
- container,
- conf,
- sparkConf,
- driverUrl,
- executorId,
- executorHostname,
- executorMemory,
- executorCores,
- appAttemptId.getApplicationId.toString,
- securityMgr)
- launcherPool.execute(executorRunnable)
- }
- }
- logDebug("""
- Finished allocating %s containers (from %s originally).
- Current number of executors running: %d,
- Released containers: %s
- """.format(
- allocatedContainersToProcess,
- allocatedContainers,
- numExecutorsRunning.get(),
- releasedContainers))
+ handleAllocatedContainers(allocatedContainers)
}
val completedContainers = allocateResponse.getCompletedContainersStatuses()
if (completedContainers.size > 0) {
logDebug("Completed %d containers".format(completedContainers.size))
- for (completedContainer <- completedContainers) {
- val containerId = completedContainer.getContainerId
+ processCompletedContainers(completedContainers)
- if (releasedContainers.containsKey(containerId)) {
- // Already marked the container for release, so remove it from
- // `releasedContainers`.
- releasedContainers.remove(containerId)
- } else {
- // Decrement the number of executors running. The next iteration of
- // the ApplicationMaster's reporting thread will take care of allocating.
- numExecutorsRunning.decrementAndGet()
- logInfo("Completed container %s (state: %s, exit status: %s)".format(
- containerId,
- completedContainer.getState,
- completedContainer.getExitStatus))
- // Hadoop 2.2.X added a ContainerExitStatus we should switch to use
- // there are some exit status' we shouldn't necessarily count against us, but for
- // now I think its ok as none of the containers are expected to exit
- if (completedContainer.getExitStatus == -103) { // vmem limit exceeded
- logWarning(memLimitExceededLogMessage(
- completedContainer.getDiagnostics,
- VMEM_EXCEEDED_PATTERN))
- } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded
- logWarning(memLimitExceededLogMessage(
- completedContainer.getDiagnostics,
- PMEM_EXCEEDED_PATTERN))
- } else if (completedContainer.getExitStatus != 0) {
- logInfo("Container marked as failed: " + containerId +
- ". Exit status: " + completedContainer.getExitStatus +
- ". Diagnostics: " + completedContainer.getDiagnostics)
- numExecutorsFailed.incrementAndGet()
- }
- }
-
- allocatedHostToContainersMap.synchronized {
- if (allocatedContainerToHostMap.containsKey(containerId)) {
- val hostOpt = allocatedContainerToHostMap.get(containerId)
- assert(hostOpt.isDefined)
- val host = hostOpt.get
-
- val containerSetOpt = allocatedHostToContainersMap.get(host)
- assert(containerSetOpt.isDefined)
- val containerSet = containerSetOpt.get
-
- containerSet.remove(containerId)
- if (containerSet.isEmpty) {
- allocatedHostToContainersMap.remove(host)
- } else {
- allocatedHostToContainersMap.update(host, containerSet)
- }
-
- allocatedContainerToHostMap.remove(containerId)
-
- // TODO: Move this part outside the synchronized block?
- val rack = YarnSparkHadoopUtil.lookupRack(conf, host)
- if (rack != null) {
- val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
- if (rackCount > 0) {
- allocatedRackCount.put(rack, rackCount)
- } else {
- allocatedRackCount.remove(rack)
- }
- }
- }
- }
- }
- logDebug("""
- Finished processing %d completed containers.
- Current number of executors running: %d,
- Released containers: %s
- """.format(
- completedContainers.size,
- numExecutorsRunning.get(),
- releasedContainers))
+ logDebug("Finished processing %d completed containers. Current running executor count: %d."
+ .format(completedContainers.size, numExecutorsRunning))
}
}
- private def allocatedContainersOnHost(host: String): Int = {
- allocatedHostToContainersMap.synchronized {
- allocatedHostToContainersMap.getOrElse(host, Set()).size
+ /**
+ * Request numExecutors additional containers from YARN. Visible for testing.
+ */
+ def addResourceRequests(numExecutors: Int): Unit = {
+ for (i <- 0 until numExecutors) {
+ val request = new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY)
+ amClient.addContainerRequest(request)
+ val nodes = request.getNodes
+ val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last
+ logInfo("Container request (host: %s, capability: %s".format(hostStr, resource))
}
}
- private def allocatedContainersOnRack(rack: String): Int = {
- allocatedHostToContainersMap.synchronized {
- allocatedRackCount.getOrElse(rack, 0)
+ /**
+ * Handle containers granted by the RM by launching executors on them.
+ *
+ * Due to the way the YARN allocation protocol works, certain healthy race conditions can result
+ * in YARN granting containers that we no longer need. In this case, we release them.
+ *
+ * Visible for testing.
+ */
+ def handleAllocatedContainers(allocatedContainers: Seq[Container]): Unit = {
+ val containersToUse = new ArrayBuffer[Container](allocatedContainers.size)
+
+ // Match incoming requests by host
+ val remainingAfterHostMatches = new ArrayBuffer[Container]
+ for (allocatedContainer <- allocatedContainers) {
+ matchContainerToRequest(allocatedContainer, allocatedContainer.getNodeId.getHost,
+ containersToUse, remainingAfterHostMatches)
}
- }
- private def isResourceConstraintSatisfied(container: Container): Boolean = {
- container.getResource.getMemory >= (executorMemory + memoryOverhead)
- }
-
- // A simple method to copy the split info map.
- private def generateNodeToWeight(
- conf: Configuration,
- input: collection.Map[String, collection.Set[SplitInfo]])
- : (Map[String, Int], Map[String, Int]) = {
- if (input == null) {
- return (Map[String, Int](), Map[String, Int]())
+ // Match remaining by rack
+ val remainingAfterRackMatches = new ArrayBuffer[Container]
+ for (allocatedContainer <- remainingAfterHostMatches) {
+ val rack = RackResolver.resolve(conf, allocatedContainer.getNodeId.getHost).getNetworkLocation
+ matchContainerToRequest(allocatedContainer, rack, containersToUse,
+ remainingAfterRackMatches)
}
- val hostToCount = new HashMap[String, Int]
- val rackToCount = new HashMap[String, Int]
-
- for ((host, splits) <- input) {
- val hostCount = hostToCount.getOrElse(host, 0)
- hostToCount.put(host, hostCount + splits.size)
+ // Assign remaining that are neither node-local nor rack-local
+ val remainingAfterOffRackMatches = new ArrayBuffer[Container]
+ for (allocatedContainer <- remainingAfterRackMatches) {
+ matchContainerToRequest(allocatedContainer, ANY_HOST, containersToUse,
+ remainingAfterOffRackMatches)
+ }
- val rack = YarnSparkHadoopUtil.lookupRack(conf, host)
- if (rack != null) {
- val rackCount = rackToCount.getOrElse(host, 0)
- rackToCount.put(host, rackCount + splits.size)
+ if (!remainingAfterOffRackMatches.isEmpty) {
+ logDebug(s"Releasing ${remainingAfterOffRackMatches.size} unneeded containers that were " +
+ s"allocated to us")
+ for (container <- remainingAfterOffRackMatches) {
+ internalReleaseContainer(container)
}
}
- (hostToCount.toMap, rackToCount.toMap)
- }
+ runAllocatedContainers(containersToUse)
- private def internalReleaseContainer(container: Container): Unit = {
- releasedContainers.put(container.getId(), true)
- amClient.releaseAssignedContainer(container.getId())
+ logInfo("Received %d containers from YARN, launching executors on %d of them."
+ .format(allocatedContainers.size, containersToUse.size))
}
/**
- * Called to allocate containers in the cluster.
+ * Looks for requests for the given location that match the given container allocation. If it
+ * finds one, removes the request so that it won't be submitted again. Places the container into
+ * containersToUse or remaining.
*
- * @param count Number of containers to allocate.
- * If zero, should still contact RM (as a heartbeat).
- * @return Response to the allocation request.
+ * @param allocatedContainer container that was given to us by YARN
+ * @param location resource name, either a node, rack, or *
+ * @param containersToUse list of containers that will be used
+ * @param remaining list of containers that will not be used
*/
- private def allocateContainers(count: Int): AllocateResponse = {
- addResourceRequests(count)
-
- // We have already set the container request. Poll the ResourceManager for a response.
- // This doubles as a heartbeat if there are no pending container requests.
- val progressIndicator = 0.1f
- amClient.allocate(progressIndicator)
+ private def matchContainerToRequest(
+ allocatedContainer: Container,
+ location: String,
+ containersToUse: ArrayBuffer[Container],
+ remaining: ArrayBuffer[Container]): Unit = {
+ val matchingRequests = amClient.getMatchingRequests(allocatedContainer.getPriority, location,
+ allocatedContainer.getResource)
+
+ // Match the allocation to a request
+ if (!matchingRequests.isEmpty) {
+ val containerRequest = matchingRequests.get(0).iterator.next
+ amClient.removeContainerRequest(containerRequest)
+ containersToUse += allocatedContainer
+ } else {
+ remaining += allocatedContainer
+ }
}
- private def createRackResourceRequests(hostContainers: ArrayBuffer[ContainerRequest])
- : ArrayBuffer[ContainerRequest] = {
- // Generate modified racks and new set of hosts under it before issuing requests.
- val rackToCounts = new HashMap[String, Int]()
-
- for (container <- hostContainers) {
- val candidateHost = container.getNodes.last
- assert(YarnSparkHadoopUtil.ANY_HOST != candidateHost)
-
- val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost)
- if (rack != null) {
- var count = rackToCounts.getOrElse(rack, 0)
- count += 1
- rackToCounts.put(rack, count)
+ /**
+ * Launches executors in the allocated containers.
+ */
+ private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = {
+ for (container <- containersToUse) {
+ numExecutorsRunning += 1
+ assert(numExecutorsRunning <= maxExecutors)
+ val executorHostname = container.getNodeId.getHost
+ val containerId = container.getId
+ executorIdCounter += 1
+ val executorId = executorIdCounter.toString
+
+ assert(container.getResource.getMemory >= resource.getMemory)
+
+ logInfo("Launching container %s for on host %s".format(containerId, executorHostname))
+
+ val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname,
+ new HashSet[ContainerId])
+
+ containerSet += containerId
+ allocatedContainerToHostMap.put(containerId, executorHostname)
+
+ val executorRunnable = new ExecutorRunnable(
+ container,
+ conf,
+ sparkConf,
+ driverUrl,
+ executorId,
+ executorHostname,
+ executorMemory,
+ executorCores,
+ appAttemptId.getApplicationId.toString,
+ securityMgr)
+ if (launchContainers) {
+ logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format(
+ driverUrl, executorHostname))
+ launcherPool.execute(executorRunnable)
}
}
-
- val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size)
- for ((rack, count) <- rackToCounts) {
- requestedContainers ++= createResourceRequests(
- AllocationType.RACK,
- rack,
- count,
- RM_REQUEST_PRIORITY)
- }
-
- requestedContainers
}
- private def addResourceRequests(numExecutors: Int): Unit = {
- val containerRequests: List[ContainerRequest] =
- if (numExecutors <= 0) {
- logDebug("numExecutors: " + numExecutors)
- List()
- } else if (preferredHostToCount.isEmpty) {
- logDebug("host preferences is empty")
- createResourceRequests(
- AllocationType.ANY,
- resource = null,
- numExecutors,
- RM_REQUEST_PRIORITY).toList
+ private def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = {
+ for (completedContainer <- completedContainers) {
+ val containerId = completedContainer.getContainerId
+
+ if (releasedContainers.contains(containerId)) {
+ // Already marked the container for release, so remove it from
+ // `releasedContainers`.
+ releasedContainers.remove(containerId)
} else {
- // Request for all hosts in preferred nodes and for numExecutors -
- // candidates.size, request by default allocation policy.
- val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size)
- for ((candidateHost, candidateCount) <- preferredHostToCount) {
- val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
-
- if (requiredCount > 0) {
- hostContainerRequests ++= createResourceRequests(
- AllocationType.HOST,
- candidateHost,
- requiredCount,
- RM_REQUEST_PRIORITY)
- }
+ // Decrement the number of executors running. The next iteration of
+ // the ApplicationMaster's reporting thread will take care of allocating.
+ numExecutorsRunning -= 1
+ logInfo("Completed container %s (state: %s, exit status: %s)".format(
+ containerId,
+ completedContainer.getState,
+ completedContainer.getExitStatus))
+ // Hadoop 2.2.X added a ContainerExitStatus we should switch to use;
+ // there are some exit statuses we shouldn't necessarily count against us, but
+ // for now I think it's ok as none of the containers are expected to exit
+ if (completedContainer.getExitStatus == -103) { // vmem limit exceeded
+ logWarning(memLimitExceededLogMessage(
+ completedContainer.getDiagnostics,
+ VMEM_EXCEEDED_PATTERN))
+ } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded
+ logWarning(memLimitExceededLogMessage(
+ completedContainer.getDiagnostics,
+ PMEM_EXCEEDED_PATTERN))
+ } else if (completedContainer.getExitStatus != 0) {
+ logInfo("Container marked as failed: " + containerId +
+ ". Exit status: " + completedContainer.getExitStatus +
+ ". Diagnostics: " + completedContainer.getDiagnostics)
+ numExecutorsFailed += 1
}
- val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests(
- hostContainerRequests).toList
-
- val anyContainerRequests = createResourceRequests(
- AllocationType.ANY,
- resource = null,
- numExecutors,
- RM_REQUEST_PRIORITY)
-
- val containerRequestBuffer = new ArrayBuffer[ContainerRequest](
- hostContainerRequests.size + rackContainerRequests.size + anyContainerRequests.size)
-
- containerRequestBuffer ++= hostContainerRequests
- containerRequestBuffer ++= rackContainerRequests
- containerRequestBuffer ++= anyContainerRequests
- containerRequestBuffer.toList
}
- for (request <- containerRequests) {
- amClient.addContainerRequest(request)
- }
+ if (allocatedContainerToHostMap.containsKey(containerId)) {
+ val host = allocatedContainerToHostMap.get(containerId).get
+ val containerSet = allocatedHostToContainersMap.get(host).get
- for (request <- containerRequests) {
- val nodes = request.getNodes
- val hostStr = if (nodes == null || nodes.isEmpty) {
- "Any"
- } else {
- nodes.last
- }
- logInfo("Container request (host: %s, priority: %s, capability: %s".format(
- hostStr,
- request.getPriority().getPriority,
- request.getCapability))
- }
- }
+ containerSet.remove(containerId)
+ if (containerSet.isEmpty) {
+ allocatedHostToContainersMap.remove(host)
+ } else {
+ allocatedHostToContainersMap.update(host, containerSet)
+ }
- private def createResourceRequests(
- requestType: AllocationType.AllocationType,
- resource: String,
- numExecutors: Int,
- priority: Int): ArrayBuffer[ContainerRequest] = {
- // If hostname is specified, then we need at least two requests - node local and rack local.
- // There must be a third request, which is ANY. That will be specially handled.
- requestType match {
- case AllocationType.HOST => {
- assert(YarnSparkHadoopUtil.ANY_HOST != resource)
- val hostname = resource
- val nodeLocal = constructContainerRequests(
- Array(hostname),
- racks = null,
- numExecutors,
- priority)
-
- // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler.
- YarnSparkHadoopUtil.populateRackInfo(conf, hostname)
- nodeLocal
- }
- case AllocationType.RACK => {
- val rack = resource
- constructContainerRequests(hosts = null, Array(rack), numExecutors, priority)
+ allocatedContainerToHostMap.remove(containerId)
}
- case AllocationType.ANY => constructContainerRequests(
- hosts = null, racks = null, numExecutors, priority)
- case _ => throw new IllegalArgumentException(
- "Unexpected/unsupported request type: " + requestType)
}
}
- private def constructContainerRequests(
- hosts: Array[String],
- racks: Array[String],
- numExecutors: Int,
- priority: Int
- ): ArrayBuffer[ContainerRequest] = {
- val memoryRequest = executorMemory + memoryOverhead
- val resource = Resource.newInstance(memoryRequest, executorCores)
-
- val prioritySetting = Records.newRecord(classOf[Priority])
- prioritySetting.setPriority(priority)
-
- val requests = new ArrayBuffer[ContainerRequest]()
- for (i <- 0 until numExecutors) {
- requests += new ContainerRequest(resource, hosts, racks, prioritySetting)
- }
- requests
+ private def internalReleaseContainer(container: Container): Unit = {
+ releasedContainers.add(container.getId())
+ amClient.releaseAssignedContainer(container.getId())
}
}
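
The rewritten matching logic amounts to a three-pass waterfall: charge each granted container against a host-local request, then a rack-local one, then ANY, and release whatever is left over. A simplified, self-contained sketch of that scheme (plain strings stand in for YARN's Container and ContainerRequest types):

```scala
import scala.collection.mutable

// Simplified model: a container is identified by its host, and `outstanding`
// maps a location (host, rack, or "*") to the number of unmatched requests.
def matchContainers(
    containers: Seq[String],
    rackOf: String => String,
    outstanding: mutable.Map[String, Int]): (Seq[String], Seq[String]) = {
  val use = mutable.ArrayBuffer[String]()

  // Returns the containers left unmatched at this locality level.
  def matchPass(cs: Seq[String], locationOf: String => String): Seq[String] =
    cs.filter { host =>
      val location = locationOf(host)
      if (outstanding.getOrElse(location, 0) > 0) {
        outstanding(location) -= 1 // drop the request so it is not re-submitted
        use += host
        false // matched: not carried into the next pass
      } else true
    }

  // Pass 1: host-local; pass 2: rack-local; pass 3: any location.
  val afterHost = matchPass(containers, identity)
  val afterRack = matchPass(afterHost, rackOf)
  val release = matchPass(afterRack, _ => "*") // leftovers are released
  (use.toSeq, release.toSeq)
}
```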
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index b45e599588ad3..b134751366522 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -72,8 +72,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg
amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)
registered = true
}
- new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args,
- preferredNodeLocations, securityMgr)
+ new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr)
}
/**
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index d7cf904db1c9e..4bff846123619 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.api.records.ApplicationAccessType
+import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType}
import org.apache.hadoop.yarn.util.RackResolver
import org.apache.hadoop.conf.Configuration
@@ -99,13 +99,7 @@ object YarnSparkHadoopUtil {
// All RM requests are issued with same priority : we do not (yet) have any distinction between
// request types (like map/reduce in hadoop for example)
- val RM_REQUEST_PRIORITY = 1
-
- // Host to rack map - saved from allocation requests. We are expecting this not to change.
- // Note that it is possible for this to change : and ResourceManager will indicate that to us via
- // update response to allocate. But we are punting on handling that for now.
- private val hostToRack = new ConcurrentHashMap[String, String]()
- private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
+ val RM_REQUEST_PRIORITY = Priority.newInstance(1)
/**
* Add a path variable to the given environment map.
@@ -184,37 +178,6 @@ object YarnSparkHadoopUtil {
}
}
- def lookupRack(conf: Configuration, host: String): String = {
- if (!hostToRack.contains(host)) {
- populateRackInfo(conf, host)
- }
- hostToRack.get(host)
- }
-
- def populateRackInfo(conf: Configuration, hostname: String) {
- Utils.checkHost(hostname)
-
- if (!hostToRack.containsKey(hostname)) {
- // If there are repeated failures to resolve, all to an ignore list.
- val rackInfo = RackResolver.resolve(conf, hostname)
- if (rackInfo != null && rackInfo.getNetworkLocation != null) {
- val rack = rackInfo.getNetworkLocation
- hostToRack.put(hostname, rack)
- if (! rackToHostSet.containsKey(rack)) {
- rackToHostSet.putIfAbsent(rack,
- Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
- }
- rackToHostSet.get(rack).add(hostname)
-
- // TODO(harvey): Figure out what this comment means...
- // Since RackResolver caches, we are disabling this for now ...
- } /* else {
- // right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
- hostToRack.put(hostname, null)
- } */
- }
- }
-
def getApplicationAclsForYarn(securityMgr: SecurityManager)
: Map[ApplicationAccessType, String] = {
Map[ApplicationAccessType, String] (
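
The hand-rolled hostToRack/rackToHostSet caches can go because Hadoop's RackResolver already caches resolutions behind its static resolve() entry point (its default mappings extend CachedDNSToSwitchMapping). A small sketch of the replacement lookup, assuming a stock Hadoop client; `rackOf` is an illustrative helper, not an API introduced by this patch:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.util.RackResolver

object RackLookupSketch {
  // Sketch only: resolve a host to its "/rack1"-style network location.
  // The Option(...) wrapping mirrors what the scheduler call sites do.
  def rackOf(conf: Configuration, host: String): Option[String] =
    Option(RackResolver.resolve(conf, host).getNetworkLocation)
}
```
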
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
index 254774a6b839e..2fa24cc43325e 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -17,8 +17,9 @@
package org.apache.spark.scheduler.cluster
+import org.apache.hadoop.yarn.util.RackResolver
+
import org.apache.spark._
-import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils
@@ -30,6 +31,6 @@ private[spark] class YarnClientClusterScheduler(sc: SparkContext) extends TaskSc
// By default, rack is unknown
override def getRackForHost(hostPort: String): Option[String] = {
val host = Utils.parseHostPort(hostPort)._1
- Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host))
+ Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation)
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 4157ff95c2794..be55d26f1cf61 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -17,8 +17,10 @@
package org.apache.spark.scheduler.cluster
+import org.apache.hadoop.yarn.util.RackResolver
+
import org.apache.spark._
-import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil}
+import org.apache.spark.deploy.yarn.ApplicationMaster
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.util.Utils
@@ -39,7 +41,7 @@ private[spark] class YarnClusterScheduler(sc: SparkContext) extends TaskSchedule
// By default, rack is unknown
override def getRackForHost(hostPort: String): Option[String] = {
val host = Utils.parseHostPort(hostPort)._1
- Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host))
+ Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation)
}
override def postStartHook() {
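
Both YARN schedulers now share this one-liner. Worth noting: with no topology script or table configured, Hadoop's default ScriptBasedMapping answers "/default-rack" for every host, so the Option() wrapper mostly guards against a null network location rather than a missing mapping. A runnable sketch under that stock-configuration assumption:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.util.RackResolver

object DefaultRackSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    // With no topology mapping configured this prints "/default-rack",
    // not null, for any host name.
    println(RackResolver.resolve(conf, "some-host").getNetworkLocation)
  }
}
```
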
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 8d184a09d64cc..024b25f9d3365 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -17,18 +17,160 @@
package org.apache.spark.deploy.yarn
+import java.util.{Arrays, List => JList}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic
+import org.apache.hadoop.net.DNSToSwitchMapping
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+
+import org.apache.spark.SecurityManager
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
import org.apache.spark.deploy.yarn.YarnAllocator._
-import org.scalatest.FunSuite
+import org.apache.spark.scheduler.SplitInfo
+
+import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
+
+class MockResolver extends DNSToSwitchMapping {
+
+ override def resolve(names: JList[String]): JList[String] = {
+ if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2")
+ else Arrays.asList("/rack1")
+ }
+
+ override def reloadCachedMappings() {}
+
+ def reloadCachedMappings(names: JList[String]) {}
+}
+
+class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach {
+ val conf = new Configuration()
+ conf.setClass(
+ CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
+ classOf[MockResolver], classOf[DNSToSwitchMapping])
+
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.driver.host", "localhost")
+ sparkConf.set("spark.driver.port", "4040")
+ sparkConf.set("spark.yarn.jar", "notarealjar.jar")
+ sparkConf.set("spark.yarn.launchContainers", "false")
+
+ val appAttemptId = ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0)
+
+ // Resource returned by YARN. YARN can give larger containers than requested, so give 6 cores
+ // instead of the 5 requested and 3 GB instead of the 2 requested.
+ val containerResource = Resource.newInstance(3072, 6)
+
+ var rmClient: AMRMClient[ContainerRequest] = _
+
+ var containerNum = 0
+
+ override def beforeEach() {
+ rmClient = AMRMClient.createAMRMClient()
+ rmClient.init(conf)
+ rmClient.start()
+ }
+
+ override def afterEach() {
+ rmClient.stop()
+ }
+
+ class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) {
+ override def equals(other: Any) = false
+ }
+
+ def createAllocator(maxExecutors: Int = 5): YarnAllocator = {
+ val args = Array(
+ "--num-executors", s"$maxExecutors",
+ "--executor-cores", "5",
+ "--executor-memory", "2048",
+ "--jar", "somejar.jar",
+ "--class", "SomeClass")
+ new YarnAllocator(
+ conf,
+ sparkConf,
+ rmClient,
+ appAttemptId,
+ new ApplicationMasterArguments(args),
+ new SecurityManager(sparkConf))
+ }
+
+ def createContainer(host: String): Container = {
+ val containerId = ContainerId.newInstance(appAttemptId, containerNum)
+ containerNum += 1
+ val nodeId = NodeId.newInstance(host, 1000)
+ Container.newInstance(containerId, nodeId, "", containerResource, RM_REQUEST_PRIORITY, null)
+ }
+
+ test("single container allocated") {
+ // request a single container and receive it
+ val handler = createAllocator()
+ handler.addResourceRequests(1)
+ handler.getNumExecutorsRunning should be (0)
+ handler.getNumPendingAllocate should be (1)
+
+ val container = createContainer("host1")
+ handler.handleAllocatedContainers(Array(container))
+
+ handler.getNumExecutorsRunning should be (1)
+ handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1")
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId)
+ rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size should be (0)
+ }
+
+ test("some containers allocated") {
+ // request a few containers and receive some of them
+ val handler = createAllocator()
+ handler.addResourceRequests(4)
+ handler.getNumExecutorsRunning should be (0)
+ handler.getNumPendingAllocate should be (4)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host1")
+ val container3 = createContainer("host2")
+ handler.handleAllocatedContainers(Array(container1, container2, container3))
+
+ handler.getNumExecutorsRunning should be (3)
+ handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1")
+ handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host1")
+ handler.allocatedContainerToHostMap.get(container3.getId).get should be ("host2")
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId)
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container2.getId)
+ handler.allocatedHostToContainersMap.get("host2").get should contain (container3.getId)
+ }
+
+ test("receive more containers than requested") {
+ val handler = createAllocator(2)
+ handler.addResourceRequests(2)
+ handler.getNumExecutorsRunning should be (0)
+ handler.getNumPendingAllocate should be (2)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host2")
+ val container3 = createContainer("host4")
+ handler.handleAllocatedContainers(Array(container1, container2, container3))
+
+ handler.getNumExecutorsRunning should be (2)
+ handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1")
+ handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host2")
+ handler.allocatedContainerToHostMap.contains(container3.getId) should be (false)
+ handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId)
+ handler.allocatedHostToContainersMap.get("host2").get should contain (container2.getId)
+ handler.allocatedHostToContainersMap.contains("host4") should be (false)
+ }
-class YarnAllocatorSuite extends FunSuite {
test("memory exceeded diagnostic regexes") {
val diagnostics =
"Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " +
- "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " +
- "5.8 GB of 4.2 GB virtual memory used. Killing container."
+ "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " +
+ "5.8 GB of 4.2 GB virtual memory used. Killing container."
val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN)
val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN)
assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used."))
assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used."))
}
+
}
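
The suite swaps in MockResolver through Hadoop's standard topology hook rather than a topology script, which keeps rack answers deterministic. Below is a self-contained sketch of the same wiring; `OneRackMapping` and `TopologyWiringSketch` are illustrative stand-ins, not code from this patch:

```scala
import java.util.{Collections, List => JList}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.CommonConfigurationKeysPublic
import org.apache.hadoop.net.DNSToSwitchMapping
import org.apache.hadoop.yarn.util.RackResolver

// Sketch only: a fixed topology that pins every host to /rack1.
class OneRackMapping extends DNSToSwitchMapping {
  override def resolve(names: JList[String]): JList[String] =
    Collections.nCopies(names.size, "/rack1")
  override def reloadCachedMappings(): Unit = {}
  def reloadCachedMappings(names: JList[String]): Unit = {}
}

object TopologyWiringSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    // The same hook the suite uses: route rack resolution to the fake mapping.
    conf.setClass(
      CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
      classOf[OneRackMapping], classOf[DNSToSwitchMapping])
    println(RackResolver.resolve(conf, "anyhost").getNetworkLocation) // "/rack1"
  }
}
```

One caveat: RackResolver initializes its mapping from the first Configuration it sees and then holds it statically, so tests that depend on different mappings within one JVM can interfere with each other.
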