diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 89eec7d4b7f61..c99a61f63ea2b 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -10,3 +10,4 @@ log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO +log4j.logger.org.apache.hadoop.yarn.util.RackResolver=WARN diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index f02b035a980b1..a1f7133f897ee 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -121,6 +121,14 @@ pre { border: none; } +.description-input { + overflow: hidden; + text-overflow: ellipsis; + width: 100%; + white-space: nowrap; + display: block; +} + .stacktrace-details { max-height: 300px; overflow-y: auto; diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index a0ee2a7cbb2a2..b28da192c1c0d 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -158,7 +158,7 @@ private[spark] class ExecutorAllocationManager( "shuffle service. You may enable this through spark.shuffle.service.enabled.") } if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores") + throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a0ce107f43b16..f9d4aa4240e9d 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,6 +18,7 @@ package org.apache.spark import scala.collection.JavaConverters._ +import scala.collection.concurrent.TrieMap import scala.collection.mutable.{HashMap, LinkedHashSet} import org.apache.spark.serializer.KryoSerializer @@ -46,7 +47,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new HashMap[String, String]() + private[spark] val settings = new TrieMap[String, String]() if (loadDefaults) { // Load any spark.* system properties @@ -177,7 +178,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.clone().toArray + def getAll: Array[(String, String)] = settings.toArray /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a1dfb01062591..33a7aae5d3fcd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ 
-168,7 +168,7 @@ private[spark] class TaskSchedulerImpl( if (!hasLaunchedTask) { logWarning("Initial job has not accepted any resources; " + "check your cluster UI to ensure that workers are registered " + - "and have sufficient memory") + "and have sufficient resources") } else { this.cancel() } diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 662a7b91248aa..fa8a337ad63a8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -92,7 +92,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade } override def deserializeStream(s: InputStream): DeserializationStream = { - new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader) + new JavaDeserializationStream(s, defaultClassLoader) } def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index b4677447c8872..fc1844600f1cb 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -22,20 +22,23 @@ import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.scheduler.SchedulingMode +// scalastyle:off /** * Continuously generates jobs that expose various features of the WebUI (internal testing tool). * - * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] + * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] [#job set (4 jobs per set)] */ +// scalastyle:on private[spark] object UIWorkloadGenerator { val NUM_PARTITIONS = 100 val INTER_JOB_WAIT_MS = 5000 def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 3) { println( - "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") + "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + + "[master] [FIFO|FAIR] [#job set (4 jobs per set)]") System.exit(1) } @@ -45,6 +48,7 @@ private[spark] object UIWorkloadGenerator { if (schedulingMode == SchedulingMode.FAIR) { conf.set("spark.scheduler.mode", "FAIR") } + val nJobSet = args(2).toInt val sc = new SparkContext(conf) def setProperties(s: String) = { @@ -84,7 +88,7 @@ private[spark] object UIWorkloadGenerator { ("Job with delays", baseData.map(x => Thread.sleep(100)).count) ) - while (true) { + (1 to nJobSet).foreach { _ => for ((desc, job) <- jobs) { new Thread { override def run() { @@ -101,5 +105,6 @@ private[spark] object UIWorkloadGenerator { Thread.sleep(INTER_JOB_WAIT_MS) } } + sc.stop() } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 81212708ba524..045c69da06feb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -64,7 +64,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} -
{lastStageDescription}
+ {lastStageDescription} {lastStageName} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 1da7a988203db..479f967fb1541 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -59,54 +59,86 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) val poolTable = new PoolTable(pools, parent) + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = pendingStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty + val summary: NodeSeq =
- val content = summary ++ - {if (sc.isDefined && isFairScheduler) { -

{pools.size} Fair Scheduler Pools

++ poolTable.toNodeSeq - } else { - Seq[Node]() - }} ++ -

Active Stages ({activeStages.size})

++ - activeStagesTable.toNodeSeq ++ -

Pending Stages ({pendingStages.size})

++ - pendingStagesTable.toNodeSeq ++ -

Completed Stages ({numCompletedStages})

++ - completedStagesTable.toNodeSeq ++ -

Failed Stages ({numFailedStages})

++ + var content = summary ++ + { + if (sc.isDefined && isFairScheduler) { +

{pools.size} Fair Scheduler Pools

++ poolTable.toNodeSeq + } else { + Seq[Node]() + } + } + if (shouldShowActiveStages) { + content ++=

Active Stages ({activeStages.size})

++ + activeStagesTable.toNodeSeq + } + if (shouldShowPendingStages) { + content ++=

Pending Stages ({pendingStages.size})

++ + pendingStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

Completed Stages ({numCompletedStages})

++ + completedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

Failed Stages ({numFailedStages})

++ failedStagesTable.toNodeSeq - + } UIUtils.headerSparkPage("Spark Stages (for all jobs)", content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index e7d6244dcd679..703d43f9c640d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -112,9 +112,8 @@ private[ui] class StageTableBase( stageData <- listener.stageIdToData.get((s.stageId, s.attemptId)) desc <- stageData.description } yield { -
{desc}
+ {desc} } -
{stageDesc.getOrElse("")} {killLink} {nameLink} {details}
} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala index 948c350953e27..de58be38c7bfb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala @@ -54,7 +54,7 @@ object DenseGmmEM { for (i <- 0 until clusters.k) { println("weight=%f\nmu=%s\nsigma=\n%s\n" format - (clusters.weight(i), clusters.mu(i), clusters.sigma(i))) + (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } println("Cluster labels (first <= 100):") diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 8a13c74221546..2d6a825b61726 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -133,6 +133,12 @@ object GraphGenerators { // This ensures that the 4 quadrants are the same size at all recursion levels val numVertices = math.round( math.pow(2.0, math.ceil(math.log(requestedNumVertices) / math.log(2.0)))).toInt + val numEdgesUpperBound = + math.pow(2.0, 2 * ((math.log(numVertices) / math.log(2.0)) - 1)).toInt + if (numEdgesUpperBound < numEdges) { + throw new IllegalArgumentException( + s"numEdges must be <= $numEdgesUpperBound but was $numEdges") + } var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 3abefbe52fa8a..8d9c8ddccbb3c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -110,4 +110,14 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { } } + test("SPARK-5064 GraphGenerators.rmatGraph numEdges upper bound") { + withSpark { sc => + val g1 = GraphGenerators.rmatGraph(sc, 4, 4) + assert(g1.edges.count() === 4) + intercept[IllegalArgumentException] { + val g2 = GraphGenerators.rmatGraph(sc, 4, 8) + } + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 555da8c7e7ab3..430d763ef7ca7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable { k: Int, maxIterations: Int, runs: Int, - initializationMode: String): KMeansModel = { + initializationMode: String, + seed: java.lang.Long): KMeansModel = { val kMeansAlg = new KMeans() .setK(k) .setMaxIterations(maxIterations) .setRuns(runs) .setInitializationMode(initializationMode) + + if (seed != null) kMeansAlg.setSeed(seed) + try { kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) } finally { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala index d8e134619411b..899fe5e9e9cf2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -134,9 +134,7 @@ class GaussianMixtureEM private ( // diagonal covariance matrices using component variances // derived from the samples val (weights, gaussians) = initialModel match { - case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => - new MultivariateGaussian(mu, sigma) - }) + case Some(gmm) => (gmm.weights, gmm.gaussians) case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) @@ -176,10 +174,7 @@ class GaussianMixtureEM private ( iter += 1 } - // Need to convert the breeze matrices to MLlib matrices - val means = Array.tabulate(k) { i => gaussians(i).mu } - val sigmas = Array.tabulate(k) { i => gaussians(i).sigma } - new GaussianMixtureModel(weights, means, sigmas) + new GaussianMixtureModel(weights, gaussians) } /** Average of dense breeze vectors */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 416cad080c408..1a2178ee7f711 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils @@ -36,12 +36,13 @@ import org.apache.spark.mllib.util.MLUtils * covariance matrix for Gaussian i */ class GaussianMixtureModel( - val weight: Array[Double], - val mu: Array[Vector], - val sigma: Array[Matrix]) extends Serializable { + val weights: Array[Double], + val gaussians: Array[MultivariateGaussian]) extends Serializable { + + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") /** Number of gaussians in mixture */ - def k: Int = weight.length + def k: Int = weights.length /** Maps given points to their cluster indices. 
*/ def predict(points: RDD[Vector]): RDD[Int] = { @@ -55,14 +56,10 @@ class GaussianMixtureModel( */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext - val dists = sc.broadcast { - (0 until k).map { i => - new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) - }.toArray - } - val weights = sc.broadcast(weight) + val bcDists = sc.broadcast(gaussians) + val bcWeights = sc.broadcast(weights) points.map { x => - computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k) + computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 54c301d3e9e14..11633e8242313 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer -import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** @@ -43,13 +43,14 @@ class KMeans private ( private var runs: Int, private var initializationMode: String, private var initializationSteps: Int, - private var epsilon: Double) extends Serializable with Logging { + private var epsilon: Double, + private var seed: Long) extends Serializable with Logging { /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, - * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}. + * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}. */ - def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) + def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) /** Set the number of clusters to create (k). Default: 2. */ def setK(k: Int): this.type = { @@ -112,6 +113,12 @@ class KMeans private ( this } + /** Set the random seed for cluster initialization. */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. @@ -255,7 +262,7 @@ class KMeans private ( private def initRandom(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq + val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm) }.toArray) @@ -272,45 +279,81 @@ class KMeans private ( */ private def initKMeansParallel(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { - // Initialize each run's center to a random point - val seed = new XORShiftRandom().nextInt() + // Initialize empty centers and point costs. 
+ val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) + var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + + // Initialize each run's first center to a random point. + val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq - val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + + /** Merges new centers to centers. */ + def mergeNewCenters(): Unit = { + var r = 0 + while (r < runs) { + centers(r) ++= newCenters(r) + newCenters(r).clear() + r += 1 + } + } // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's current centers + // to their squared distance from that run's centers. Note that only distances between points + // and new centers are computed in each iteration. var step = 0 while (step < initializationSteps) { - val bcCenters = data.context.broadcast(centers) - val sumCosts = data.flatMap { point => - (0 until runs).map { r => - (r, KMeans.pointCost(bcCenters.value(r), point)) - } - }.reduceByKey(_ + _).collectAsMap() - val chosen = data.mapPartitionsWithIndex { (index, points) => + val bcNewCenters = data.context.broadcast(newCenters) + val preCosts = costs + costs = data.zip(preCosts).map { case (point, cost) => + Vectors.dense( + Array.tabulate(runs) { r => + math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) + }) + }.cache() + val sumCosts = costs + .aggregate(Vectors.zeros(runs))( + seqOp = (s, v) => { + // s += v + axpy(1.0, v, s) + s + }, + combOp = (s0, s1) => { + // s0 += s1 + axpy(1.0, s1, s0) + s0 + } + ) + preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) - points.flatMap { p => - (0 until runs).filter { r => - rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r) - }.map((_, p)) + pointsWithCosts.flatMap { case (p, c) => + val rs = (0 until runs).filter { r => + rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) + } + if (rs.length > 0) Some(p, rs) else None } }.collect() - chosen.foreach { case (r, p) => - centers(r) += p.toDense + mergeNewCenters() + chosen.foreach { case (p, rs) => + rs.foreach(newCenters(_) += p.toDense) } step += 1 } + mergeNewCenters() + costs.unpersist(blocking = false) + // Finally, we might have a set of more than k candidate centers for each run; weigh each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick just k of them val bcCenters = data.context.broadcast(centers) val weightMap = data.flatMap { p => - (0 until runs).map { r => + Iterator.tabulate(runs) { r => ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() - val finalCenters = (0 until runs).map { r => + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) @@ -333,7 +376,32 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. 
* - * @param data training points stored as `RDD[Array[Double]]` + * @param data training points stored as `RDD[Vector]` + * @param k number of clusters + * @param maxIterations max number of iterations + * @param runs number of parallel runs, defaults to 1. The best model is returned. + * @param initializationMode initialization model, either "random" or "k-means||" (default). + * @param seed random seed value for cluster initialization + */ + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int, + runs: Int, + initializationMode: String, + seed: Long): KMeansModel = { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setRuns(runs) + .setInitializationMode(initializationMode) + .setSeed(seed) + .run(data) + } + + /** + * Trains a k-means model using the given set of parameters. + * + * @param data training points stored as `RDD[Vector]` * @param k number of clusters * @param maxIterations max number of iterations * @param runs number of parallel runs, defaults to 1. The best model is returned. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index adbd8266ed6fa..7ee0224ad4662 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -50,13 +50,35 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v: Vector => - util.Arrays.equals(this.toArray, v.toArray) + case v2: Vector => { + if (this.size != v2.size) return false + (this, v2) match { + case (s1: SparseVector, s2: SparseVector) => + Vectors.equals(s1.indices, s1.values, s2.indices, s2.values) + case (s1: SparseVector, d1: DenseVector) => + Vectors.equals(s1.indices, s1.values, 0 until d1.size, d1.values) + case (d1: DenseVector, s1: SparseVector) => + Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) + case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) + } + } case _ => false } } - override def hashCode(): Int = util.Arrays.hashCode(this.toArray) + override def hashCode(): Int = { + var result: Int = size + 31 + this.foreachActive { case (index, value) => + // ignore explict 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + // refer to {@link java.util.Arrays.equals} for hash algorithm + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } + return result + } /** * Converts the instance to a breeze vector. 
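The equals/hashCode hunk above makes Vector comparisons representation-agnostic. A minimal sketch of the intended behavior, mirroring the VectorsSuite test added later in this patch (illustrative only, not part of the diff):

    import org.apache.spark.mllib.linalg.Vectors

    // A dense vector and an equivalent sparse vector now compare equal and hash
    // identically, even if the sparse vector carries explicit zero entries.
    val dv = Vectors.dense(0.0, 0.9, 0.0, 0.8, 0.0)
    val sv = Vectors.sparse(5, Array(1, 3), Array(0.9, 0.8))
    assert(dv == sv && dv.## == sv.##)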
@@ -392,6 +414,33 @@ object Vectors { } squaredDistance } + + /** + * Check equality between sparse/dense vectors + */ + private[mllib] def equals( + v1Indices: IndexedSeq[Int], + v1Values: Array[Double], + v2Indices: IndexedSeq[Int], + v2Values: Array[Double]): Boolean = { + val v1Size = v1Values.size + val v2Size = v2Values.size + var k1 = 0 + var k2 = 0 + var allEqual = true + while (allEqual) { + while (k1 < v1Size && v1Values(k1) == 0) k1 += 1 + while (k2 < v2Size && v2Values(k2) == 0) k2 += 1 + + if (k1 >= v1Size || k2 >= v2Size) { + return k1 >= v1Size && k2 >= v2Size // check end alignment + } + allEqual = v1Indices(k1) == v2Indices(k2) && v1Values(k1) == v2Values(k2) + k1 += 1 + k2 += 1 + } + allEqual + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 06d8915f3bfa1..b60559c853a50 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -69,6 +69,11 @@ class CoordinateMatrix( nRows } + /** Transposes this CoordinateMatrix. */ + def transpose(): CoordinateMatrix = { + new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows()) + } + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ def toIndexedRowMatrix(): IndexedRowMatrix = { val nl = numCols() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 181f507516485..c518271f04729 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -75,6 +75,23 @@ class IndexedRowMatrix( new RowMatrix(rows.map(_.vector), 0L, nCols) } + /** + * Converts this matrix to a + * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. + */ + def toCoordinateMatrix(): CoordinateMatrix = { + val entries = rows.flatMap { row => + val rowIndex = row.index + row.vector match { + case SparseVector(size, indices, values) => + Iterator.tabulate(indices.size)(i => MatrixEntry(rowIndex, indices(i), values(i))) + case DenseVector(values) => + Iterator.tabulate(values.size)(i => MatrixEntry(rowIndex, i, values(i))) + } + } + new CoordinateMatrix(entries, numRows(), numCols()) + } + /** * Computes the singular value decomposition of this IndexedRowMatrix. * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c65a9..ed8e6a796f8c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -68,6 +68,15 @@ case class BoostingStrategy( @Experimental object BoostingStrategy { + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. 
Supported: "Classification" or "Regression" + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + defaultParams(Algo.fromString(algo)) + } + /** * Returns default configuration for the boosting algorithm * @param algo Learning goal. Supported: @@ -75,15 +84,15 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) - treeStrategy.maxDepth = 3 + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 algo match { - case "Classification" => - treeStrategy.numClasses = 2 - new BoostingStrategy(treeStrategy, LogLoss) - case "Regression" => - new BoostingStrategy(treeStrategy, SquaredError) + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d5cd89ab94e81..972959885f396 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -173,11 +173,19 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { - case "Classification" => + def defaultStrategy(algo: String): Strategy = { + defaultStategy(Algo.fromString(algo)) + } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) - case "Regression" => + case Algo.Regression => new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala index 9da5495741a80..198997b5bb2b2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Vectors, Matrices} +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -39,9 +40,9 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex val seeds = Array(314589, 29032897, 50181, 494821, 4660) seeds.foreach { seed => val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data) - assert(gmm.weight(0) ~== Ew absTol 1E-5) - assert(gmm.mu(0) ~== Emu absTol 1E-5) - assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + 
assert(gmm.weights(0) ~== Ew absTol 1E-5) + assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) + assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) } } @@ -57,8 +58,10 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( Array(0.5, 0.5), - Array(Vectors.dense(-1.0), Vectors.dense(1.0)), - Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0))) + Array( + new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))), + new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) + ) ) val Ew = Array(1.0 / 3.0, 2.0 / 3.0) @@ -70,11 +73,11 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex .setInitialModel(initialGmm) .run(data) - assert(gmm.weight(0) ~== Ew(0) absTol 1E-3) - assert(gmm.weight(1) ~== Ew(1) absTol 1E-3) - assert(gmm.mu(0) ~== Emu(0) absTol 1E-3) - assert(gmm.mu(1) ~== Emu(1) absTol 1E-3) - assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3) - assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3) + assert(gmm.weights(0) ~== Ew(0) absTol 1E-3) + assert(gmm.weights(1) ~== Ew(1) absTol 1E-3) + assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3) + assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3) + assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) + assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 9ebef8466c831..caee5917000aa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { assert(model.clusterCenters.size === 3) } + test("deterministic initialization") { + // Create a large-ish set of points for clustering + val points = List.tabulate(1000)(n => Vectors.dense(n, n)) + val rdd = sc.parallelize(points, 3) + + for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) { + // Create three deterministic models and compare cluster means + val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, + initializationMode = initMode, seed = 42) + val centers1 = model1.clusterCenters + + val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, + initializationMode = initMode, seed = 42) + val centers2 = model2.clusterCenters + + centers1.zip(centers2).foreach { case (c1, c2) => + assert(c1 ~== c2 absTol 1E-14) + } + } + } + test("single cluster with big dataset") { val smallData = Array( Vectors.dense(1.0, 2.0, 6.0), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 85ac8ccebfc59..5def899cea117 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -89,6 +89,24 @@ class VectorsSuite extends FunSuite { } } + test("vectors equals with explicit 0") { + val dv1 = Vectors.dense(Array(0, 0.9, 0, 0.8, 0)) + val sv1 = Vectors.sparse(5, Array(1, 3), Array(0.9, 0.8)) + val sv2 = Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(0, 0.9, 0, 0.8, 0)) + + val vectors = Seq(dv1, sv1, sv2) + for (v <- vectors; u <- vectors) { + assert(v === u) + assert(v.## === u.##) + } + + val another = Vectors.sparse(5, 
Array(0, 1, 3), Array(0, 0.9, 0.2)) + for (v <- vectors) { + assert(v != another) + assert(v.## != another.##) + } + } + test("indexing dense vectors") { val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0) assert(vec(0) === 1.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index f8709751efce6..80bef814ce50d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -73,6 +73,11 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(mat.toBreeze() === expected) } + test("transpose") { + val transposed = mat.transpose() + assert(mat.toBreeze().t === transposed.toBreeze()) + } + test("toIndexedRowMatrix") { val indexedRowMatrix = mat.toIndexedRowMatrix() val expected = BDM( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 741cd4997b853..b86c2ca5ff136 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -80,6 +80,14 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(rowMat.rows.collect().toSeq === data.map(_.vector).toSeq) } + test("toCoordinateMatrix") { + val idxRowMat = new IndexedRowMatrix(indexedRows) + val coordMat = idxRowMat.toCoordinateMatrix() + assert(coordMat.numRows() === m) + assert(coordMat.numCols() === n) + assert(coordMat.toBreeze() === idxRowMat.toBreeze()) + } + test("multiply a local matrix") { val A = new IndexedRowMatrix(indexedRows) val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) diff --git a/pom.xml b/pom.xml index f4466e56c2a53..b993391b15042 100644 --- a/pom.xml +++ b/pom.xml @@ -1507,7 +1507,7 @@ mapr3 1.0.3-mapr-3.0.3 - 2.3.0-mapr-4.0.0-FCS + 2.4.1-mapr-1408 0.94.17-mapr-1405 3.4.5-mapr-1406 @@ -1516,8 +1516,8 @@ mapr4 - 2.3.0-mapr-4.0.0-FCS - 2.3.0-mapr-4.0.0-FCS + 2.4.1-mapr-1408 + 2.4.1-mapr-1408 0.94.17-mapr-1405-4.0.0-FCS 3.4.5-mapr-1406 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 95fef23ee4f39..127973b658190 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -86,6 +86,10 @@ object MimaExcludes { // SPARK-5270 ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.isEmpty") + ) ++ Seq( + // SPARK-5297 Java FileStream do not work with custom key/values + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") ) case v if v.startsWith("1.2") => diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e2492eef5bd6a..6b713aa39374e 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -78,10 +78,10 @@ def predict(self, x): class KMeans(object): @classmethod - def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): + def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None): """Train a k-means clustering model.""" model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, - runs, 
initializationMode) + runs, initializationMode, seed) centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 140c22b5fd4e8..f48e3d6dacb4b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -140,7 +140,7 @@ class ListTests(PySparkTestCase): as NumPy arrays. """ - def test_clustering(self): + def test_kmeans(self): from pyspark.mllib.clustering import KMeans data = [ [0, 1.1], @@ -152,6 +152,21 @@ def test_clustering(self): self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + def test_kmeans_deterministic(self): + from pyspark.mllib.clustering import KMeans + X = range(0, 100, 10) + Y = range(0, 100, 10) + data = [[x, y] for x, y in zip(X, Y)] + clusters1 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + clusters2 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + centers1 = clusters1.centers + centers2 = clusters2.centers + for c1, c2 in zip(centers1, centers2): + # TODO: Allow small numeric difference. + self.assertTrue(array_equal(c1, c2)) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 208ec92987ac8..41bb4f012f2e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.util.hashing.MurmurHash3 + import org.apache.spark.sql.catalyst.expressions.GenericRow @@ -32,7 +34,7 @@ object Row { * } * }}} */ - def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + def unapplySeq(row: Row): Some[Seq[Any]] = Some(row.toSeq) /** * This method can be used to construct a [[Row]] with the given values. @@ -43,6 +45,16 @@ object Row { * This method can be used to construct a [[Row]] from a [[Seq]] of values. */ def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) + + def fromTuple(tuple: Product): Row = fromSeq(tuple.productIterator.toSeq) + + /** + * Merge multiple rows into a single row, one after another. + */ + def merge(rows: Row*): Row = { + // TODO: Improve the performance of this if used in performance critical part. + new GenericRow(rows.flatMap(_.toSeq).toArray) + } } @@ -103,7 +115,13 @@ object Row { * * @group row */ -trait Row extends Seq[Any] with Serializable { +trait Row extends Serializable { + /** Number of elements in the Row. */ + def size: Int = length + + /** Number of elements in the Row. */ + def length: Int + /** * Returns the value at position i. If the value is null, null is returned. The following * is a mapping between Spark SQL types and return types: @@ -291,12 +309,61 @@ trait Row extends Seq[Any] with Serializable { /** Returns true if there are any NULL values in this row. 
*/ def anyNull: Boolean = { - val l = length + val len = length var i = 0 - while (i < l) { + while (i < len) { if (isNullAt(i)) { return true } i += 1 } false } + + override def equals(that: Any): Boolean = that match { + case null => false + case that: Row => + if (this.length != that.length) { + return false + } + var i = 0 + val len = this.length + while (i < len) { + if (apply(i) != that.apply(i)) { + return false + } + i += 1 + } + true + case _ => false + } + + override def hashCode: Int = { + // Using Scala's Seq hash code implementation. + var n = 0 + var h = MurmurHash3.seqSeed + val len = length + while (n < len) { + h = MurmurHash3.mix(h, apply(n).##) + n += 1 + } + MurmurHash3.finalizeHash(h, n) + } + + /* ---------------------- utility methods for Scala ---------------------- */ + + /** + * Return a Scala Seq representing the row. ELements are placed in the same order in the Seq. + */ + def toSeq: Seq[Any] + + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. + */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 93d74adbcc957..366be00473d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -25,15 +25,42 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ +private[sql] object KeywordNormalizer { + def apply(str: String) = str.toLowerCase() +} + private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match { - case Success(plan, _) => plan - case failureOrError => sys.error(failureOrError.toString) + def apply(input: String): LogicalPlan = { + // Initialize the Keywords. + lexical.initialize(reservedWords) + phrase(start)(new lexical.Scanner(input)) match { + case Success(plan, _) => plan + case failureOrError => sys.error(failureOrError.toString) + } } - protected case class Keyword(str: String) + protected case class Keyword(str: String) { + def normalize = KeywordNormalizer(str) + def parser: Parser[String] = normalize + } + + protected implicit def asParser(k: Keyword): Parser[String] = k.parser + + // By default, use Reflection to find the reserved words defined in the sub class. + // NOTICE, Since the Keyword properties defined by sub class, we couldn't call this + // method during the parent class instantiation, because the sub class instance + // isn't created yet. + protected lazy val reservedWords: Seq[String] = + this + .getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].normalize) + + // Set the keywords as empty by default, will change that later. 
+ override val lexical = new SqlLexical protected def start: Parser[LogicalPlan] @@ -52,18 +79,27 @@ private[sql] abstract class AbstractSparkSQLParser } } -class SqlLexical(val keywords: Seq[String]) extends StdLexical { +class SqlLexical extends StdLexical { case class FloatLit(chars: String) extends Token { override def toString = chars } - reserved ++= keywords.flatMap(w => allCaseVersions(w)) + /* This is a work around to support the lazy setting */ + def initialize(keywords: Seq[String]): Unit = { + reserved.clear() + reserved ++= keywords + } delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" ) + protected override def processIdent(name: String) = { + val token = KeywordNormalizer(name) + if (reserved contains token) Keyword(token) else Identifier(name) + } + override lazy val token: Parser[Token] = ( identChar ~ (identChar | digit).* ^^ { case first ~ rest => processIdent((first :: rest).mkString) } @@ -94,14 +130,5 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { | '-' ~ '-' ~ chrExcept(EofCh, '\n').* | '/' ~ '*' ~ failure("unclosed comment") ).* - - /** Generate all variations of upper and lower case of a given string */ - def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s.isEmpty) { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) #::: - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d280db83b26f7..191d16fb10b5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -84,8 +84,9 @@ trait ScalaReflection { } def convertRowToScala(r: Row, schema: StructType): Row = { + // TODO: This is very slow!!! new GenericRow( - r.zip(schema.fields.map(_.dataType)) + r.toSeq.zip(schema.fields.map(_.dataType)) .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 0b36d8b9bfce5..eaadbe9fd5099 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -36,9 +36,8 @@ import org.apache.spark.sql.types._ * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ class SqlParser extends AbstractSparkSQLParser { - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") @@ -107,16 +106,6 @@ class SqlParser extends AbstractSparkSQLParser { protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") - // Use reflection to find the reserved words defined in this class. 
- protected val reservedWords = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { case (ne: NamedExpression, _) => ne diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 26c855878d202..417659eed5957 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -272,9 +272,6 @@ package object dsl { def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) - def sfilter(dynamicUdf: (DynamicRow) => Boolean) = - Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan) - def sample( fraction: Double, withReplacement: Boolean = true, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1a2133bbbcec7..ece5ee73618cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -407,7 +407,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val casts = from.fields.zip(to.fields).map { case (fromField, toField) => cast(fromField.dataType, toField.dataType) } - buildCast[Row](_, row => Row(row.zip(casts).map { + // TODO: This is very slow! 
+ buildCast[Row](_, row => Row(row.toSeq.zip(casts).map { case (v, cast) => if (v == null) null else cast(v) }: _*)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index e7e81a21fdf03..db5d897ee569f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -105,45 +105,45 @@ class JoinedRow extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -154,8 +154,16 @@ class JoinedRow extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString 
never throws NullPointerException. + if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -197,45 +205,45 @@ class JoinedRow2 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -246,8 +254,16 @@ class JoinedRow2 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -283,45 +299,45 @@ class JoinedRow3 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -332,8 +348,16 @@ class JoinedRow3 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -369,45 +393,45 @@ class JoinedRow4 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -418,8 +442,16 @@ class JoinedRow4 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -455,45 +487,45 @@ class JoinedRow5 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -504,7 +536,15 @@ class JoinedRow5 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 37d9f0ed5c79e..7434165f654f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -209,6 +209,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def length: Int = values.length + override def toSeq: Seq[Any] = values.map(_.boxed).toSeq + override def setNullAt(i: Int): Unit = { values(i).isNull = true } @@ -231,8 +233,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR if (value == null) setNullAt(ordinal) else values(ordinal).update(value) } - override def iterator: Iterator[Any] = values.map(_.boxed).iterator - override def setString(ordinal: Int, value: String) = update(ordinal, value) override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala deleted file mode 100644 index 8328278544a1e..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.language.dynamics - -import org.apache.spark.sql.types.DataType - -/** - * The data type representing [[DynamicRow]] values. - */ -case object DynamicType extends DataType - -/** - * Wrap a [[Row]] as a [[DynamicRow]]. - */ -case class WrapDynamic(children: Seq[Attribute]) extends Expression { - type EvaluatedType = DynamicRow - - def nullable = false - - def dataType = DynamicType - - override def eval(input: Row): DynamicRow = input match { - // Avoid copy for generic rows. - case g: GenericRow => new DynamicRow(children, g.values) - case otherRowType => new DynamicRow(children, otherRowType.toArray) - } -} - -/** - * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language. - * Since the type of the column is not known at compile time, all attributes are converted to - * strings before being passed to the function. 
- */ -class DynamicRow(val schema: Seq[Attribute], values: Array[Any]) - extends GenericRow(values) with Dynamic { - - def selectDynamic(attributeName: String): String = { - val ordinal = schema.indexWhere(_.name == attributeName) - values(ordinal).toString - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index cc97cb4f50b69..69397a73a8880 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -77,14 +77,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { """.children : Seq[Tree] } - val iteratorFunction = { - val allColumns = (0 until expressions.size).map { i => - val iLit = ru.Literal(Constant(i)) - q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" - } - q"override def iterator = Iterator[Any](..$allColumns)" - } - val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" val applyFunction = { val cases = (0 until expressions.size).map { i => @@ -191,20 +183,26 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } """ + val allColumns = (0 until expressions.size).map { i => + val iLit = ru.Literal(Constant(i)) + q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + } + val copyFunction = - q""" - override def copy() = new $genericRowType(this.toArray) - """ + q"override def copy() = new $genericRowType(Array[Any](..$allColumns))" + + val toSeqFunction = + q"override def toSeq: Seq[Any] = Seq(..$allColumns)" val classBody = nullFunctions ++ ( lengthDef +: - iteratorFunction +: applyFunction +: updateFunction +: equalsFunction +: hashCodeFunction +: copyFunction +: + toSeqFunction +: (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) val code = q""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index c22b8426841da..8df150e2f855f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -44,7 +44,7 @@ trait MutableRow extends Row { */ object EmptyRow extends Row { override def apply(i: Int): Any = throw new UnsupportedOperationException - override def iterator = Iterator.empty + override def toSeq = Seq.empty override def length = 0 override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException override def getInt(i: Int): Int = throw new UnsupportedOperationException @@ -70,7 +70,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { def this(size: Int) = this(new Array[Any](size)) - override def iterator = values.iterator + override def toSeq = values.toSeq override def length = values.length @@ -119,7 +119,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } // Custom hashCode function that matches the efficient code generated version. 
- override def hashCode(): Int = { + override def hashCode: Int = { var result: Int = 37 var i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 1483beacc9088..9628e93274a11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -238,16 +238,11 @@ case class Rollup( case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override lazy val statistics: Statistics = - if (output.forall(_.dataType.isInstanceOf[NativeType])) { - val limit = limitExpr.eval(null).asInstanceOf[Int] - val sizeInBytes = (limit: Long) * output.map { a => - NativeType.defaultSizeOf(a.dataType.asInstanceOf[NativeType]) - }.sum - Statistics(sizeInBytes = sizeInBytes) - } else { - Statistics(sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product) - } + override lazy val statistics: Statistics = { + val limit = limitExpr.eval(null).asInstanceOf[Int] + val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum + Statistics(sizeInBytes = sizeInBytes) + } } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index bcd74603d4013..9f30f40a173e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -215,6 +215,9 @@ abstract class DataType { case _ => false } + /** The default size of a value of this data type. 
*/ + def defaultSize: Int + def isPrimitive: Boolean = false def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase @@ -235,33 +238,25 @@ abstract class DataType { * @group dataType */ @DeveloperApi -case object NullType extends DataType +case object NullType extends DataType { + override def defaultSize: Int = 1 +} -object NativeType { +protected[sql] object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) - - val defaultSizeOf: Map[NativeType, Int] = Map( - IntegerType -> 4, - BooleanType -> 1, - LongType -> 8, - DoubleType -> 8, - FloatType -> 4, - ShortType -> 2, - ByteType -> 1, - StringType -> 4096) } -trait PrimitiveType extends DataType { +protected[sql] trait PrimitiveType extends DataType { override def isPrimitive = true } -object PrimitiveType { +protected[sql] object PrimitiveType { private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap @@ -276,7 +271,7 @@ object PrimitiveType { } } -abstract class NativeType extends DataType { +protected[sql] abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] private[sql] val ordering: Ordering[JvmType] @@ -300,6 +295,11 @@ case object StringType extends NativeType with PrimitiveType { private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the StringType is 4096 bytes. + */ + override def defaultSize: Int = 4096 } @@ -324,6 +324,11 @@ case object BinaryType extends NativeType with PrimitiveType { x.length - y.length } } + + /** + * The default size of a value of the BinaryType is 4096 bytes. + */ + override def defaultSize: Int = 4096 } @@ -339,6 +344,11 @@ case object BooleanType extends NativeType with PrimitiveType { private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the BooleanType is 1 byte. + */ + override def defaultSize: Int = 1 } @@ -359,6 +369,11 @@ case object TimestampType extends NativeType { private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } + + /** + * The default size of a value of the TimestampType is 8 bytes. + */ + override def defaultSize: Int = 8 } @@ -379,10 +394,15 @@ case object DateType extends NativeType { private[sql] val ordering = new Ordering[JvmType] { def compare(x: Date, y: Date) = x.compareTo(y) } + + /** + * The default size of a value of the DateType is 8 bytes. + */ + override def defaultSize: Int = 8 } -abstract class NumericType extends NativeType with PrimitiveType { +protected[sql] abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). 
This gets @@ -392,13 +412,13 @@ abstract class NumericType extends NativeType with PrimitiveType { } -object NumericType { +protected[sql] object NumericType { def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] } /** Matcher for any expressions that evaluate to [[IntegralType]]s */ -object IntegralType { +protected[sql] object IntegralType { def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType.isInstanceOf[IntegralType] => true case _ => false @@ -406,7 +426,7 @@ object IntegralType { } -sealed abstract class IntegralType extends NumericType { +protected[sql] sealed abstract class IntegralType extends NumericType { private[sql] val integral: Integral[JvmType] } @@ -425,6 +445,11 @@ case object LongType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the LongType is 8 bytes. + */ + override def defaultSize: Int = 8 } @@ -442,6 +467,11 @@ case object IntegerType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the IntegerType is 4 bytes. + */ + override def defaultSize: Int = 4 } @@ -459,6 +489,11 @@ case object ShortType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the ShortType is 2 bytes. + */ + override def defaultSize: Int = 2 } @@ -476,11 +511,16 @@ case object ByteType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the ByteType is 1 byte. + */ + override def defaultSize: Int = 1 } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ -object FractionalType { +protected[sql] object FractionalType { def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType.isInstanceOf[FractionalType] => true case _ => false @@ -488,7 +528,7 @@ object FractionalType { } -sealed abstract class FractionalType extends NumericType { +protected[sql] sealed abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType] private[sql] val asIntegral: Integral[JvmType] } @@ -530,6 +570,11 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" case None => "DecimalType()" } + + /** + * The default size of a value of the DecimalType is 4096 bytes. + */ + override def defaultSize: Int = 4096 } @@ -580,6 +625,11 @@ case object DoubleType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = DoubleAsIfIntegral + + /** + * The default size of a value of the DoubleType is 8 bytes. 
+ */ + override def defaultSize: Int = 8 } @@ -598,6 +648,11 @@ case object FloatType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = FloatAsIfIntegral + + /** + * The default size of a value of the FloatType is 4 bytes. + */ + override def defaultSize: Int = 4 } @@ -636,6 +691,12 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT ("type" -> typeName) ~ ("elementType" -> elementType.jsonValue) ~ ("containsNull" -> containsNull) + + /** + * The default size of a value of the ArrayType is 100 * the default size of the element type. + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * elementType.defaultSize } @@ -805,6 +866,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def length: Int = fields.length override def iterator: Iterator[StructField] = fields.iterator + + /** + * The default size of a value of the StructType is the total default sizes of all field types. + */ + override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum } @@ -848,6 +914,13 @@ case class MapType( ("keyType" -> keyType.jsonValue) ~ ("valueType" -> valueType.jsonValue) ~ ("valueContainsNull" -> valueContainsNull) + + /** + * The default size of a value of the MapType is + * 100 * (the default size of the key type + the default size of the value type). + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) } @@ -896,4 +969,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { * Class object for the UserType */ def userClass: java.lang.Class[UserType] + + /** + * The default size of a value of the UserDefinedType is 4096 bytes. + */ + override def defaultSize: Int = 4096 } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 6df5db4c80f34..5138942a55daa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -244,7 +244,7 @@ class ScalaReflectionSuite extends FunSuite { test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) + val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) val dataType = schemaFor[PrimitiveData].dataType assert(convertToCatalyst(data, dataType) === convertedData) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala new file mode 100644 index 0000000000000..1a0a0e6154ad2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Command +import org.scalatest.FunSuite + +private[sql] case class TestCommand(cmd: String) extends Command + +private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { + protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") + + override protected lazy val start: Parser[LogicalPlan] = set + + private lazy val set: Parser[LogicalPlan] = + EXECUTE ~> ident ^^ { + case fileName => TestCommand(fileName) + } +} + +private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { + protected val EXECUTE = Keyword("EXECUTE") + + override protected lazy val start: Parser[LogicalPlan] = set + + private lazy val set: Parser[LogicalPlan] = + EXECUTE ~> ident ^^ { + case fileName => TestCommand(fileName) + } +} + +class SqlParserSuite extends FunSuite { + + test("test long keyword") { + val parser = new SuperLongKeywordTestParser + assert(TestCommand("NotRealCommand") === parser("ThisIsASuperLongKeyWordTest NotRealCommand")) + } + + test("test case insensitive") { + val parser = new CaseInsensitiveTestParser + assert(TestCommand("NotRealCommand") === parser("EXECUTE NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser("execute NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser("exEcute NotRealCommand")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 892195f46ea24..c147be9f6b1ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -62,6 +62,7 @@ class DataTypeSuite extends FunSuite { } } + checkDataTypeJsonRepr(NullType) checkDataTypeJsonRepr(BooleanType) checkDataTypeJsonRepr(ByteType) checkDataTypeJsonRepr(ShortType) @@ -69,7 +70,9 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(LongType) checkDataTypeJsonRepr(FloatType) checkDataTypeJsonRepr(DoubleType) + checkDataTypeJsonRepr(DecimalType(10, 5)) checkDataTypeJsonRepr(DecimalType.Unlimited) + checkDataTypeJsonRepr(DateType) checkDataTypeJsonRepr(TimestampType) checkDataTypeJsonRepr(StringType) checkDataTypeJsonRepr(BinaryType) @@ -77,12 +80,39 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(ArrayType(StringType, false)) checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + val metadata = new MetadataBuilder() .putString("name", "age") .build() - checkDataTypeJsonRepr( - StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", ArrayType(DoubleType), nullable = false), - StructField("c", DoubleType, nullable = false, metadata)))) + val structType = StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", ArrayType(DoubleType), nullable = false), + StructField("c", DoubleType, nullable = 
false, metadata))) + checkDataTypeJsonRepr(structType) + + def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = { + test(s"Check the default size of ${dataType}") { + assert(dataType.defaultSize === expectedDefaultSize) + } + } + + checkDefaultSize(NullType, 1) + checkDefaultSize(BooleanType, 1) + checkDefaultSize(ByteType, 1) + checkDefaultSize(ShortType, 2) + checkDefaultSize(IntegerType, 4) + checkDefaultSize(LongType, 8) + checkDefaultSize(FloatType, 4) + checkDefaultSize(DoubleType, 8) + checkDefaultSize(DecimalType(10, 5), 4096) + checkDefaultSize(DecimalType.Unlimited, 4096) + checkDefaultSize(DateType, 8) + checkDefaultSize(TimestampType, 8) + checkDefaultSize(StringType, 4096) + checkDefaultSize(BinaryType, 4096) + checkDefaultSize(ArrayType(DoubleType, true), 800) + checkDefaultSize(ArrayType(StringType, false), 409600) + checkDefaultSize(MapType(IntegerType, StringType, true), 410000) + checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400) + checkDefaultSize(structType, 812) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f23cb18c92d5d..0a22968cc7807 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -107,7 +107,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } protected[sql] def parseSql(sql: String): LogicalPlan = { - ddlParser(sql).getOrElse(sqlParser(sql)) + ddlParser(sql, false).getOrElse(sqlParser(sql)) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index ae4d8ba90c5bd..d1e21dffeb8c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -330,25 +330,6 @@ class SchemaRDD( sqlContext, Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) - /** - * :: Experimental :: - * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use - * scala's Dynamic trait to emulate an ORM of in a dynamically typed language. Since the type of - * the column is not known at compile time, all attributes are converted to strings before - * being passed to the function. - * - * {{{ - * schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith") - * }}} - * - * @group Query - */ - @Experimental - def where(dynamicUdf: (DynamicRow) => Boolean) = - new SchemaRDD( - sqlContext, - Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)) - /** * :: Experimental :: * Returns a sampled version of the underlying dataset. 
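
The `defaultSize` values exercised in `DataTypeSuite` above follow directly from the sizing rules added to `dataTypes.scala`: arrays and maps assume 100 elements, and a struct's size is the sum of its field sizes. As a worked check of that arithmetic, here is a minimal self-contained sketch; the `DT` hierarchy and object names are illustrative stand-ins, not Spark's `DataType` classes.

```scala
// Stand-alone sketch of the defaultSize rules in this patch (hypothetical types,
// not Spark's DataType hierarchy), reproducing the numbers checked in DataTypeSuite.
object DefaultSizeSketch {
  sealed trait DT { def defaultSize: Int }
  case object IntDT    extends DT { val defaultSize = 4 }
  case object DoubleDT extends DT { val defaultSize = 8 }
  case object StringDT extends DT { val defaultSize = 4096 }

  // Arrays and maps assume 100 elements, as in the patch.
  final case class ArrayDT(element: DT) extends DT {
    def defaultSize: Int = 100 * element.defaultSize
  }
  final case class MapDT(key: DT, value: DT) extends DT {
    def defaultSize: Int = 100 * (key.defaultSize + value.defaultSize)
  }
  // A struct's default size is the total of its field sizes.
  final case class StructDT(fields: Seq[DT]) extends DT {
    def defaultSize: Int = fields.map(_.defaultSize).sum
  }

  def main(args: Array[String]): Unit = {
    assert(ArrayDT(DoubleDT).defaultSize == 800)                  // 100 * 8
    assert(ArrayDT(StringDT).defaultSize == 409600)               // 100 * 4096
    assert(MapDT(IntDT, StringDT).defaultSize == 410000)          // 100 * (4 + 4096)
    assert(MapDT(IntDT, ArrayDT(DoubleDT)).defaultSize == 80400)  // 100 * (4 + 800)
    assert(StructDT(Seq(IntDT, ArrayDT(DoubleDT), DoubleDT)).defaultSize == 812) // 4 + 800 + 8
  }
}
```

Pushing the estimate onto every `DataType` is also what lets `Limit` in `basicOperators.scala` compute `sizeInBytes` for any output schema, instead of falling back when the output contains non-native types.
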
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index f10ee7b66feb7..f1a4053b79113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql + import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.sql.catalyst.{SqlLexical, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand} @@ -61,18 +62,6 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val TABLE = Keyword("TABLE") protected val UNCACHE = Keyword("UNCACHE") - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - - private val reservedWords: Seq[String] = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others private lazy val cache: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 065fae3c83df1..11d5943fb427f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -21,7 +21,6 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -128,8 +127,7 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - val stats = Row.fromSeq( - columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) + val stats = Row.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -271,9 +269,10 @@ private[sql] case class InMemoryColumnarTableScan( // Extract rows via column accessors new Iterator[Row] { + private[this] val rowLen = nextRow.length override def next() = { var i = 0 - while (i < nextRow.length) { + while (i < rowLen) { columnAccessors(i).extractTo(nextRow, i) i += 1 } @@ -297,7 +296,7 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { def statsString = relation.partitionStatistics.schema - .zip(cachedBatch.stats) + .zip(cachedBatch.stats.toSeq) .map { case (a, s) => s"${a.name}: $s" } .mkString(", ") logInfo(s"Skipping partition based on stats $statsString") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 64673248394c6..68a5b1de7691b 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -127,7 +127,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { while (from.hasRemaining) { columnType.extract(from, value, 0) - if (value.head == currentValue.head) { + if (value(0) == currentValue(0)) { currentRun += 1 } else { // Writes current run diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 46245cd5a1869..4d7e338e8ed13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -144,7 +144,7 @@ package object debug { case (null, _) => case (row: Row, StructType(fields)) => - row.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) } + row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (s: Seq[_], ArrayType(elemType, _)) => s.foreach(typeCheck(_, elemType)) case (m: Map[_, _], MapType(keyType, valueType, _)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 7ed64aad10d4e..b85021acc9d4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -116,9 +116,9 @@ object EvaluatePython { def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (row: Seq[Any], struct: StructType) => + case (row: Row, struct: StructType) => val fields = struct.fields.map(field => field.dataType) - row.zip(fields).map { + row.toSeq.zip(fields).map { case (obj, dataType) => toJava(obj, dataType) }.toArray @@ -143,7 +143,8 @@ object EvaluatePython { * Convert Row into Java Array (for pickled into Python) */ def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { - row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray + // TODO: this is slow! + row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray } // Converts value to the type specified by the data type. 
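
Because `Row` no longer behaves as a `Seq[Any]`, call sites that used to zip a row directly with its schema (as in `debug/package.scala` and `pythonUdfs.scala` above) now go through `toSeq` first. A minimal stand-alone sketch of that call-site pattern follows; `MiniRow` and `Field` are hypothetical stand-ins, not Spark's `Row` or `StructField`.

```scala
// Illustrative only: shows the row.toSeq.zip(fields) pattern used in this patch.
object RowToSeqSketch {
  final case class Field(name: String, dataType: String)

  final class MiniRow(values: Array[Any]) {
    // The row exposes its contents as a Seq instead of being one itself.
    def toSeq: Seq[Any] = values.toSeq
    def length: Int = values.length
  }

  // Pair each value with its field description before converting or printing it.
  def describe(row: MiniRow, fields: Seq[Field]): Seq[String] =
    row.toSeq.zip(fields).map { case (v, f) => s"${f.name} (${f.dataType}) = $v" }

  def main(args: Array[String]): Unit = {
    val row    = new MiniRow(Array(1, "a"))
    val schema = Seq(Field("key", "int"), Field("value", "string"))
    // Prints: List(key (int) = 1, value (string) = a)
    println(describe(row, schema))
  }
}
```

The `TODO: this is slow!` note in `rowToArray` above flags the cost of this pattern: the explicit `toSeq` allocates a new collection per row, which is acceptable at these call sites but worth revisiting on hot paths.
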
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index db70a7eac72b9..9171939f7e8f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -458,16 +458,16 @@ private[sql] object JsonRDD extends Logging { gen.writeEndArray() case (MapType(kv,vv, _), v: Map[_,_]) => - gen.writeStartObject + gen.writeStartObject() v.foreach { p => gen.writeFieldName(p._1.toString) valWriter(vv,p._2) } - gen.writeEndObject + gen.writeEndObject() - case (StructType(ty), v: Seq[_]) => + case (StructType(ty), v: Row) => gen.writeStartObject() - ty.zip(v).foreach { + ty.zip(v.toSeq).foreach { case (_, null) => case (field, v) => gen.writeFieldName(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index b4aed04199129..9d9150246c8d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -66,7 +66,7 @@ private[sql] object CatalystConverter { // TODO: consider using Array[T] for arrays to avoid boxing of primitive types type ArrayScalaType[T] = Seq[T] - type StructScalaType[T] = Seq[T] + type StructScalaType[T] = Row type MapScalaType[K, V] = Map[K, V] protected[parquet] def createConverter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 381298caba6f2..171b816a26332 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -18,32 +18,32 @@ package org.apache.spark.sql.sources import scala.language.implicitConversions -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers import org.apache.spark.Logging import org.apache.spark.sql.{SchemaRDD, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ import org.apache.spark.util.Utils + /** * A parser for foreign DDL commands. 
*/ -private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging { - - def apply(input: String): Option[LogicalPlan] = { - phrase(ddl)(new lexical.Scanner(input)) match { - case Success(r, x) => Some(r) - case x => - logDebug(s"Not recognized as DDL: $x") - None +private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { + + def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { + try { + Some(apply(input)) + } catch { + case _ if !exceptionOnError => None + case x: Throwable => throw x } } def parseType(input: String): DataType = { + lexical.initialize(reservedWords) phrase(dataType)(new lexical.Scanner(input)) match { case Success(r, x) => r case x => @@ -51,11 +51,9 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi } } - protected case class Keyword(str: String) - - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object protected val CREATE = Keyword("CREATE") protected val TEMPORARY = Keyword("TEMPORARY") protected val TABLE = Keyword("TABLE") @@ -80,17 +78,10 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected val MAP = Keyword("MAP") protected val STRUCT = Keyword("STRUCT") - // Use reflection to find the reserved words defined in this class. - protected val reservedWords = - this.getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - protected lazy val ddl: Parser[LogicalPlan] = createTable + protected def start: Parser[LogicalPlan] = ddl + /** * `CREATE [TEMPORARY] TABLE avroTable * USING org.apache.spark.sql.avro diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 2bcfe28456997..afbfe214f1ce4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -45,28 +45,28 @@ class DslQuerySuite extends QueryTest { test("agg") { checkAnswer( testData2.groupBy('a)('a, sum('b)), - Seq((1,3),(2,3),(3,3)) + Seq(Row(1,3), Row(2,3), Row(3,3)) ) checkAnswer( testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), - 9 + Row(9) ) checkAnswer( testData2.aggregate(sum('b)), - 9 + Row(9) ) } test("convert $\"attribute name\" into unresolved attribute") { checkAnswer( testData.where($"key" === 1).select($"value"), - Seq(Seq("1"))) + Row("1")) } test("convert Scala Symbol 'attrname into unresolved attribute") { checkAnswer( testData.where('key === 1).select('value), - Seq(Seq("1"))) + Row("1")) } test("select *") { @@ -78,61 +78,61 @@ class DslQuerySuite extends QueryTest { test("simple select") { checkAnswer( testData.where('key === 1).select('value), - Seq(Seq("1"))) + Row("1")) } test("select with functions") { checkAnswer( testData.select(sum('value), avg('value), count(1)), - Seq(Seq(5050.0, 50.5, 100))) + Row(5050.0, 50.5, 100)) checkAnswer( testData2.select('a + 'b, 'a < 'b), Seq( - Seq(2, false), - Seq(3, true), - Seq(3, false), - Seq(4, false), - Seq(4, false), - Seq(5, false))) + Row(2, false), + Row(3, true), + Row(3, false), + Row(4, false), + Row(4, false), + Row(5, 
false))) checkAnswer( testData2.select(sumDistinct('a)), - Seq(Seq(6))) + Row(6)) } test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), - Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) checkAnswer( testData2.orderBy('a.asc, 'b.desc), - Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) checkAnswer( testData2.orderBy('a.desc, 'b.desc), - Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) checkAnswer( testData2.orderBy('a.desc, 'b.asc), - Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) checkAnswer( arrayData.orderBy('data.getItem(0).asc), - arrayData.collect().sortBy(_.data(0)).toSeq) + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(0).desc), - arrayData.collect().sortBy(_.data(0)).reverse.toSeq) + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( - mapData.orderBy('data.getItem(1).asc), - mapData.collect().sortBy(_.data(1)).toSeq) + arrayData.orderBy('data.getItem(1).asc), + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( - mapData.orderBy('data.getItem(1).desc), - mapData.collect().sortBy(_.data(1)).reverse.toSeq) + arrayData.orderBy('data.getItem(1).desc), + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("partition wide sorting") { @@ -147,19 +147,19 @@ class DslQuerySuite extends QueryTest { // (3, 2) checkAnswer( testData2.sortBy('a.asc, 'b.asc), - Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) checkAnswer( testData2.sortBy('a.asc, 'b.desc), - Seq((1,2), (1,1), (2,1), (2,2), (3,2), (3,1))) + Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1))) checkAnswer( testData2.sortBy('a.desc, 'b.desc), - Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2))) + Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2))) checkAnswer( testData2.sortBy('a.desc, 'b.asc), - Seq((2,1), (1,1), (1,2), (3,1), (3,2), (2,2))) + Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2))) } test("limit") { @@ -169,11 +169,11 @@ class DslQuerySuite extends QueryTest { checkAnswer( arrayData.limit(1), - arrayData.take(1).toSeq) + arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) checkAnswer( mapData.limit(1), - mapData.take(1).toSeq) + mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } test("SPARK-3395 limit distinct") { @@ -184,8 +184,8 @@ class DslQuerySuite extends QueryTest { .registerTempTable("onerow") checkAnswer( sql("select * from onerow inner join testData2 on onerow.a = testData2.a"), - (1, 1, 1, 1) :: - (1, 1, 1, 2) :: Nil) + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: Nil) } test("SPARK-3858 generator qualifiers are discarded") { @@ -193,55 +193,55 @@ class DslQuerySuite extends QueryTest { arrayData.as('ad) .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) .select("ex.data".attr), - Seq(1, 2, 3, 2, 3, 4).map(Seq(_))) + Seq(1, 2, 3, 2, 3, 4).map(Row(_))) } test("average") { checkAnswer( testData2.aggregate(avg('a)), - 2.0) + Row(2.0)) checkAnswer( testData2.aggregate(avg('a), sumDistinct('a)), // non-partial - (2.0, 6.0) :: Nil) + Row(2.0, 6.0) :: Nil) checkAnswer( decimalData.aggregate(avg('a)), - new 
java.math.BigDecimal(2.0)) + Row(new java.math.BigDecimal(2.0))) checkAnswer( decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial - (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) checkAnswer( decimalData.aggregate(avg('a cast DecimalType(10, 2))), - new java.math.BigDecimal(2.0)) + Row(new java.math.BigDecimal(2.0))) checkAnswer( decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial - (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } test("null average") { checkAnswer( testData3.aggregate(avg('b)), - 2.0) + Row(2.0)) checkAnswer( testData3.aggregate(avg('b), countDistinct('b)), - (2.0, 1) :: Nil) + Row(2.0, 1)) checkAnswer( testData3.aggregate(avg('b), sumDistinct('b)), // non-partial - (2.0, 2.0) :: Nil) + Row(2.0, 2.0)) } test("zero average") { checkAnswer( emptyTableData.aggregate(avg('a)), - null) + Row(null)) checkAnswer( emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial - (null, null) :: Nil) + Row(null, null)) } test("count") { @@ -249,28 +249,28 @@ class DslQuerySuite extends QueryTest { checkAnswer( testData2.aggregate(count('a), sumDistinct('a)), // non-partial - (6, 6.0) :: Nil) + Row(6, 6.0)) } test("null count") { checkAnswer( testData3.groupBy('a)('a, count('b)), - Seq((1,0), (2, 1)) + Seq(Row(1,0), Row(2, 1)) ) checkAnswer( testData3.groupBy('a)('a, count('a + 'b)), - Seq((1,0), (2, 1)) + Seq(Row(1,0), Row(2, 1)) ) checkAnswer( testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), - (2, 1, 2, 2, 1) :: Nil + Row(2, 1, 2, 2, 1) ) checkAnswer( testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial - (1, 1, 2) :: Nil + Row(1, 1, 2) ) } @@ -279,28 +279,28 @@ class DslQuerySuite extends QueryTest { checkAnswer( emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial - (0, null) :: Nil) + Row(0, null)) } test("zero sum") { checkAnswer( emptyTableData.aggregate(sum('a)), - null) + Row(null)) } test("zero sum distinct") { checkAnswer( emptyTableData.aggregate(sumDistinct('a)), - null) + Row(null)) } test("except") { checkAnswer( lowerCaseData.except(upperCaseData), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.except(lowerCaseData), Nil) checkAnswer(upperCaseData.except(upperCaseData), Nil) } @@ -308,10 +308,10 @@ class DslQuerySuite extends QueryTest { test("intersect") { checkAnswer( lowerCaseData.intersect(lowerCaseData), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) } @@ -321,75 +321,75 @@ class DslQuerySuite extends QueryTest { checkAnswer( // SELECT *, foo(key, value) FROM testData testData.select(Star(None), foo.call('key, 'value)).limit(3), - (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil + Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } test("sqrt") { checkAnswer( testData.select(sqrt('key)).orderBy('key asc), - (1 to 100).map(n => Seq(math.sqrt(n))) + (1 to 100).map(n => Row(math.sqrt(n))) ) checkAnswer( testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc), - (1 to 100).map(n => Seq(math.sqrt(n), n)) + (1 to 100).map(n => 
Row(math.sqrt(n), n)) ) checkAnswer( testData.select(sqrt(Literal(null))), - (1 to 100).map(_ => Seq(null)) + (1 to 100).map(_ => Row(null)) ) } test("abs") { checkAnswer( testData.select(abs('key)).orderBy('key asc), - (1 to 100).map(n => Seq(n)) + (1 to 100).map(n => Row(n)) ) checkAnswer( negativeData.select(abs('key)).orderBy('key desc), - (1 to 100).map(n => Seq(n)) + (1 to 100).map(n => Row(n)) ) checkAnswer( testData.select(abs(Literal(null))), - (1 to 100).map(_ => Seq(null)) + (1 to 100).map(_ => Row(null)) ) } test("upper") { checkAnswer( lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Seq(c.toString.toUpperCase())) + ('a' to 'd').map(c => Row(c.toString.toUpperCase())) ) checkAnswer( testData.select(upper('value), 'key), - (1 to 100).map(n => Seq(n.toString, n)) + (1 to 100).map(n => Row(n.toString, n)) ) checkAnswer( testData.select(upper(Literal(null))), - (1 to 100).map(n => Seq(null)) + (1 to 100).map(n => Row(null)) ) } test("lower") { checkAnswer( upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Seq(c.toString.toLowerCase())) + ('A' to 'F').map(c => Row(c.toString.toLowerCase())) ) checkAnswer( testData.select(lower('value), 'key), - (1 to 100).map(n => Seq(n.toString, n)) + (1 to 100).map(n => Row(n.toString, n)) ) checkAnswer( testData.select(lower(Literal(null))), - (1 to 100).map(n => Seq(null)) + (1 to 100).map(n => Row(null)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e5ab16f9dd661..cd36da7751e83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -117,10 +117,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( upperCaseData.join(lowerCaseData, Inner).where('n === 'N), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d") )) } @@ -128,10 +128,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d") )) } @@ -140,10 +140,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.where('a === 1).as('y) checkAnswer( x.join(y).where("x.a".attr === "y.a".attr), - (1,1,1,1) :: - (1,1,1,2) :: - (1,2,1,1) :: - (1,2,1,2) :: Nil + Row(1,1,1,1) :: + Row(1,1,1,2) :: + Row(1,2,1,1) :: + Row(1,2,1,2) :: Nil ) } @@ -163,54 +163,54 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), testData.flatMap( - row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } test("cartisian product join") { checkAnswer( testData3.join(testData3), - (1, null, 1, null) :: - (1, null, 2, 2) :: - (2, 2, 1, null) :: - (2, 2, 2, 2) :: Nil) + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 2) :: Nil) } test("left outer join") { checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", 1, "a") :: + Row(2, "B", 2, 
"b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), - (1, "A", null, null) :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), - (1, "A", null, null) :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. @@ -221,12 +221,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY l.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) checkAnswer( sql( @@ -235,42 +235,42 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY r.a """.stripMargin), - (null, 6) :: Nil) + Row(null, 6) :: Nil) } test("right outer join") { checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), - (null, null, 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), - (null, null, 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + 
Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. @@ -281,7 +281,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY l.a """.stripMargin), - (null, 6) :: Nil) + Row(null, 6)) checkAnswer( sql( @@ -290,12 +290,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) } test("full outer join") { @@ -307,32 +307,32 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", null, null) :: - (null, null, 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", null, null) :: - (null, null, 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. 
checkAnswer( @@ -342,7 +342,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY l.a """.stripMargin), - (null, 10) :: Nil) + Row(null, 10)) checkAnswer( sql( @@ -351,13 +351,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: - (null, 4) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) checkAnswer( sql( @@ -366,13 +366,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY l.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: - (null, 4) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) checkAnswer( sql( @@ -381,7 +381,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY r.a """.stripMargin), - (null, 10) :: Nil) + Row(null, 10)) } test("broadcasted left semi join operator selection") { @@ -412,12 +412,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("left semi join") { val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(rdd, - (1, 1) :: - (1, 2) :: - (2, 1) :: - (2, 2) :: - (3, 1) :: - (3, 2) :: Nil) + Row(1, 1) :: + Row(1, 2) :: + Row(2, 1) :: + Row(2, 2) :: + Row(3, 1) :: + Row(3, 2) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 68ddecc7f610d..42a21c148df53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -47,26 +47,17 @@ class QueryTest extends PlanTest { * @param rdd the [[SchemaRDD]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = { - val convertedAnswer = expectedAnswer match { - case s: Seq[_] if s.isEmpty => s - case s: Seq[_] if s.head.isInstanceOf[Product] && - !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq) - case s: Seq[_] => s - case singleItem => Seq(Seq(singleItem)) - } - + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - def prepareAnswer(answer: Seq[Any]): Seq[Any] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). 
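The QueryTest change above replaces the old `Any`-typed `checkAnswer` with overloads taking `Seq[Row]` or a single `Row`, and normalizes both expected and actual answers before comparing. A Spark-free sketch of that normalization idea, written over plain `Seq[Any]` values (the helper name mirrors the diff; everything else is illustrative):

```scala
// Mirrors prepareAnswer: java.math.BigDecimal values become scala.math.BigDecimal,
// whose equality follows compareTo and ignores scale, and rows are sorted by a
// stable string key unless the query itself contained an ORDER BY.
def prepareAnswer(answer: Seq[Seq[Any]], isSorted: Boolean): Seq[Seq[Any]] = {
  val converted = answer.map(_.map {
    case d: java.math.BigDecimal => BigDecimal(d)
    case other => other
  })
  if (isSorted) converted else converted.sortBy(_.toString)
}

// An unsorted result matches its expectation regardless of row order.
assert(prepareAnswer(Seq(Seq(2, "b"), Seq(1, "a")), isSorted = false) ==
  prepareAnswer(Seq(Seq(1, "a"), Seq(2, "b")), isSorted = false))
```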
- val converted = answer.map { - case s: Seq[_] => s.map { + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) case o => o - } - case o => o + }) } if (!isSorted) converted.sortBy(_.toString) else converted } @@ -82,7 +73,7 @@ class QueryTest extends PlanTest { """.stripMargin) } - if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { fail(s""" |Results do not match for query: |${rdd.logicalPlan} @@ -92,15 +83,19 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), s"== Spark Answer - ${sparkAnswer.size} ==" +: prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } - def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = { + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 54fabc5c915fb..03b44ca1d6695 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -46,7 +46,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), - Seq(1, 1, 2 ,2 ,3 ,3).map(Seq(_)) + Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_)) ) } @@ -70,13 +70,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3176 Added Parser of SQL ABS()") { checkAnswer( sql("SELECT ABS(-1.3)"), - 1.3) + Row(1.3)) checkAnswer( sql("SELECT ABS(0.0)"), - 0.0) + Row(0.0)) checkAnswer( sql("SELECT ABS(2.5)"), - 2.5) + Row(2.5)) } test("aggregation with codegen") { @@ -89,13 +89,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), - 4) + Row(4)) } test("SPARK-2041 column name equals tablename") { checkAnswer( sql("SELECT tableName FROM tableName"), - "test") + Row("test")) } test("SQRT") { @@ -115,40 +115,40 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-2407 Added Parser of SQL SUBSTR()") { checkAnswer( sql("SELECT substr(tableName, 1, 2) FROM tableName"), - "te") + Row("te")) checkAnswer( sql("SELECT substr(tableName, 3) FROM tableName"), - "st") + Row("st")) checkAnswer( sql("SELECT substring(tableName, 1, 2) FROM tableName"), - "te") + Row("te")) checkAnswer( sql("SELECT substring(tableName, 3) FROM tableName"), - "st") + Row("st")) } test("SPARK-3173 Timestamp support in the parser") { checkAnswer(sql( "SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"), - 
Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) checkAnswer(sql( "SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) checkAnswer(sql( """SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003' AND time>'1970-01-01 00:00:00.001'"""), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), - Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + Seq(Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='123'"), @@ -158,13 +158,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("index into array") { checkAnswer( sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), - arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq) + arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect()) } test("left semi greater than predicate") { checkAnswer( sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq((3,1), (3,2)) + Seq(Row(3,1), Row(3,2)) ) } @@ -173,7 +173,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql( "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"), arrayData.map(d => - (d.nestedData, + Row(d.nestedData, d.nestedData(0)(0), d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq) } @@ -181,13 +181,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq((1,3),(2,3),(3,3))) + Seq(Row(1,3), Row(2,3), Row(3,3))) } test("aggregates with nulls") { checkAnswer( sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - (1, 3, 2, 6, 3) :: Nil + Row(1, 3, 2, 6, 3) ) } @@ -200,29 +200,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("simple select") { checkAnswer( sql("SELECT value FROM testData WHERE key = 1"), - Seq(Seq("1"))) + Row("1")) } def sortTest() = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), - Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), - Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), - Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), - Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a ASC"), - (1 to 5).map(Row(_)).toSeq) + (1 to 5).map(Row(_))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a DESC"), @@ -230,19 +230,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( 
sql("SELECT * FROM arrayData ORDER BY data[0] ASC"), - arrayData.collect().sortBy(_.data(0)).toSeq) + arrayData.collect().sortBy(_.data(0)).map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM arrayData ORDER BY data[0] DESC"), - arrayData.collect().sortBy(_.data(0)).reverse.toSeq) + arrayData.collect().sortBy(_.data(0)).reverse.map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM mapData ORDER BY data[1] ASC"), - mapData.collect().sortBy(_.data(1)).toSeq) + mapData.collect().sortBy(_.data(1)).map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM mapData ORDER BY data[1] DESC"), - mapData.collect().sortBy(_.data(1)).reverse.toSeq) + mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } test("sorting") { @@ -266,94 +266,94 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM arrayData LIMIT 1"), - arrayData.collect().take(1).toSeq) + arrayData.collect().take(1).map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM mapData LIMIT 1"), - mapData.collect().take(1).toSeq) + mapData.collect().take(1).map(Row.fromTuple).toSeq) } test("from follow multiple brackets") { checkAnswer(sql( "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), - 1 + Row(1) ) checkAnswer(sql( "select key from (select * from testData) x limit 1"), - 1 + Row(1) ) checkAnswer(sql( "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"), - 1 + Row(1) ) } test("average") { checkAnswer( sql("SELECT AVG(a) FROM testData2"), - 2.0) + Row(2.0)) } test("average overflow") { checkAnswer( sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), - Seq((2147483645.0,1),(2.0,2))) + Seq(Row(2147483645.0,1), Row(2.0,2))) } test("count") { checkAnswer( sql("SELECT COUNT(*) FROM testData2"), - testData2.count()) + Row(testData2.count())) } test("count distinct") { checkAnswer( sql("SELECT COUNT(DISTINCT b) FROM testData2"), - 2) + Row(2)) } test("approximate count distinct") { checkAnswer( sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"), - 3) + Row(3)) } test("approximate count distinct with user provided standard deviation") { checkAnswer( sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"), - 3) + Row(3)) } test("null count") { checkAnswer( sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), - Seq((1, 0), (2, 1))) + Seq(Row(1, 0), Row(2, 1))) checkAnswer( sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), - (2, 1, 2, 2, 1) :: Nil) + Row(2, 1, 2, 2, 1)) } test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d"))) + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d"))) } test("inner join ON, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d"))) + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d"))) } test("inner join, where, multiple matches") { @@ -363,10 +363,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y |WHERE x.a = y.a""".stripMargin), - (1,1,1,1) :: - (1,1,1,2) :: - (1,2,1,1) :: - (1,2,1,2) :: Nil) + Row(1,1,1,1) :: + 
Row(1,1,1,2) :: + Row(1,2,1,1) :: + Row(1,2,1,2) :: Nil) } test("inner join, no matches") { @@ -397,38 +397,38 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | SELECT * FROM testData) y |WHERE x.key = y.key""".stripMargin), testData.flatMap( - row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } ignore("cartesian product join") { checkAnswer( testData3.join(testData3), - (1, null, 1, null) :: - (1, null, 2, 2) :: - (2, 2, 1, null) :: - (2, 2, 2, 2) :: Nil) + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 2) :: Nil) } test("left outer join") { checkAnswer( sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) } test("right outer join") { checkAnswer( sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("full outer join") { @@ -440,12 +440,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM upperCaseData WHERE N >= 3) rightTable | ON leftTable.N = rightTable.N """.stripMargin), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row (4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("SPARK-3349 partitioning after limit") { @@ -457,12 +457,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .registerTempTable("subset2") checkAnswer( sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), - (3, "c", 3) :: - (4, "d", 4) :: Nil) + Row(3, "c", 3) :: + Row(4, "d", 4) :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), - (1, "a", 1) :: - (2, "b", 2) :: Nil) + Row(1, "a", 1) :: + Row(2, "b", 2) :: Nil) } test("mixed-case keywords") { @@ -474,28 +474,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (sElEcT * FROM upperCaseData whERe N >= 3) rightTable | oN leftTable.N = rightTable.N """.stripMargin), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("select with table name as qualifier") { checkAnswer( sql("SELECT testData.value FROM testData WHERE testData.key = 1"), - Seq(Seq("1"))) + Row("1")) } test("inner join ON with table name as qualifier") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d"))) + Row(1, "A", 1, "a"), + Row(2, 
"B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d"))) } test("qualified select with inner join ON with table name as qualifier") { @@ -503,72 +503,72 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " + "ON lowerCaseData.n = upperCaseData.N"), Seq( - (1, "A"), - (2, "B"), - (3, "C"), - (4, "D"))) + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"))) } test("system function upper()") { checkAnswer( sql("SELECT n,UPPER(l) FROM lowerCaseData"), Seq( - (1, "A"), - (2, "B"), - (3, "C"), - (4, "D"))) + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"))) checkAnswer( sql("SELECT n, UPPER(s) FROM nullStrings"), Seq( - (1, "ABC"), - (2, "ABC"), - (3, null))) + Row(1, "ABC"), + Row(2, "ABC"), + Row(3, null))) } test("system function lower()") { checkAnswer( sql("SELECT N,LOWER(L) FROM upperCaseData"), Seq( - (1, "a"), - (2, "b"), - (3, "c"), - (4, "d"), - (5, "e"), - (6, "f"))) + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(5, "e"), + Row(6, "f"))) checkAnswer( sql("SELECT n, LOWER(s) FROM nullStrings"), Seq( - (1, "abc"), - (2, "abc"), - (3, null))) + Row(1, "abc"), + Row(2, "abc"), + Row(3, null))) } test("UNION") { checkAnswer( sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), - (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: - (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: + Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), - (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil) + Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), - (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") :: - (4, "d") :: (4, "d") :: Nil) + Row(1, "a") :: Row(1, "a") :: Row(2, "b") :: Row(2, "b") :: Row(3, "c") :: Row(3, "c") :: + Row(4, "d") :: Row(4, "d") :: Nil) } test("UNION with column mismatches") { // Column name mismatches are allowed. checkAnswer( sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), - (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: - (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: + Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) // Column type mismatches are not allowed, forcing a type coercion. checkAnswer( sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), - ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_))) + ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) // Column type mismatches where a coercion is not possible, in this case between integer // and array types, trigger a TreeNodeException. 
intercept[TreeNodeException[_]] { @@ -579,10 +579,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("EXCEPT") { checkAnswer( sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) checkAnswer( @@ -592,10 +592,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("INTERSECT") { checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil) } @@ -613,25 +613,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(s"$testKey=$testVal")) + Row(s"$testKey=$testVal") ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(s"$testKey=$testVal"), - Seq(s"${testKey + testKey}=${testVal + testVal}")) + Row(s"$testKey=$testVal"), + Row(s"${testKey + testKey}=${testVal + testVal}")) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(s"$testKey=$testVal")) + Row(s"$testKey=$testVal") ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(s"$nonexistentKey=")) + Row(s"$nonexistentKey=") ) conf.clear() } @@ -655,17 +655,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { schemaRDD1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), - (1, "A1", true, null) :: - (2, "B2", false, null) :: - (3, "C3", true, null) :: - (4, "D4", true, 2147483644) :: Nil) + Row(1, "A1", true, null) :: + Row(2, "B2", false, null) :: + Row(3, "C3", true, null) :: + Row(4, "D4", true, 2147483644) :: Nil) checkAnswer( sql("SELECT f1, f4 FROM applySchema1"), - (1, null) :: - (2, null) :: - (3, null) :: - (4, 2147483644) :: Nil) + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) val schema2 = StructType( StructField("f1", StructType( @@ -685,17 +685,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { schemaRDD2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), - (Seq(1, true), Map("A1" -> null)) :: - (Seq(2, false), Map("B2" -> null)) :: - (Seq(3, true), Map("C3" -> null)) :: - (Seq(4, true), Map("D4" -> 2147483644)) :: Nil) + Row(Row(1, true), Map("A1" -> null)) :: + Row(Row(2, false), Map("B2" -> null)) :: + Row(Row(3, true), Map("C3" -> null)) :: + Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil) checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), - (1, null) :: - (2, null) :: - (3, null) :: - (4, 2147483644) :: Nil) + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) // The value of a MapType column can be a mutable map. 
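For the applySchema tests above, struct-typed columns in the expected answer are now written as nested `Row`s while map-typed columns stay as ordinary Scala `Map`s. The value below mirrors the first expected row of `applySchema2`; the positional field access matches the `row(0)` style the column suites in this diff switch to:

```scala
import org.apache.spark.sql.Row

// Struct column -> nested Row; MapType column -> plain Scala Map.
val expectedFirst = Row(Row(1, true), Map("A1" -> null))

// Fields are read back by position.
val nestedStruct = expectedFirst(0).asInstanceOf[Row]
assert(nestedStruct.toSeq == Seq(1, true))
```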
val rowRDD3 = unparsedStrings.map { r => @@ -711,26 +711,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), - (1, null) :: - (2, null) :: - (3, null) :: - (4, 2147483644) :: Nil) + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) } test("SPARK-3423 BETWEEN") { checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"), - Seq((5, "5"), (6, "6"), (7, "7")) + Seq(Row(5, "5"), Row(6, "6"), Row(7, "7")) ) checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"), - Seq((7, "7")) + Row(7, "7") ) checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), - Seq() + Nil ) } @@ -738,7 +738,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // TODO Ensure true/false string letter casing is consistent with Hive in all cases. checkAnswer( sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), - ("true", "false") :: Nil) + Row("true", "false")) } test("metadata is propagated correctly") { @@ -768,17 +768,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3371 Renaming a function expression with group by gives error") { udf.register("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), + Row(1)) } test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { checkAnswer( - sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1) + sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), + Row(1)) } test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { checkAnswer( - sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), + Row(1)) } test("throw errors for non-aggregate attributes with aggregation") { @@ -808,130 +811,131 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Test to check we can use Long.MinValue") { checkAnswer( - sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Long.MinValue + sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue) ) checkAnswer( - sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), (1 to 100).map(Row(_)).toSeq + sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), + (1 to 100).map(Row(_)).toSeq ) } test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), 0.3 + sql("SELECT 0.3"), Row(0.3) ) checkAnswer( - sql("SELECT -0.8"), -0.8 + sql("SELECT -0.8"), Row(-0.8) ) checkAnswer( - sql("SELECT .5"), 0.5 + sql("SELECT .5"), Row(0.5) ) checkAnswer( - sql("SELECT -.18"), -0.18 + sql("SELECT -.18"), Row(-0.18) ) } test("Auto cast integer type") { checkAnswer( - sql(s"SELECT ${Int.MaxValue + 1L}"), Int.MaxValue + 1L + sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L) ) checkAnswer( - sql(s"SELECT ${Int.MinValue - 1L}"), Int.MinValue - 1L + sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L) ) checkAnswer( - sql("SELECT 9223372036854775808"), new java.math.BigDecimal("9223372036854775808") + sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) ) checkAnswer( - sql("SELECT -9223372036854775809"), new 
java.math.BigDecimal("-9223372036854775809") + sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) ) } test("Test to check we can apply sign to expression") { checkAnswer( - sql("SELECT -100"), -100 + sql("SELECT -100"), Row(-100) ) checkAnswer( - sql("SELECT +230"), 230 + sql("SELECT +230"), Row(230) ) checkAnswer( - sql("SELECT -5.2"), -5.2 + sql("SELECT -5.2"), Row(-5.2) ) checkAnswer( - sql("SELECT +6.8"), 6.8 + sql("SELECT +6.8"), Row(6.8) ) checkAnswer( - sql("SELECT -key FROM testData WHERE key = 2"), -2 + sql("SELECT -key FROM testData WHERE key = 2"), Row(-2) ) checkAnswer( - sql("SELECT +key FROM testData WHERE key = 3"), 3 + sql("SELECT +key FROM testData WHERE key = 3"), Row(3) ) checkAnswer( - sql("SELECT -(key + 1) FROM testData WHERE key = 1"), -2 + sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2) ) checkAnswer( - sql("SELECT - key + 1 FROM testData WHERE key = 10"), -9 + sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9) ) checkAnswer( - sql("SELECT +(key + 5) FROM testData WHERE key = 5"), 10 + sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10) ) checkAnswer( - sql("SELECT -MAX(key) FROM testData"), -100 + sql("SELECT -MAX(key) FROM testData"), Row(-100) ) checkAnswer( - sql("SELECT +MAX(key) FROM testData"), 100 + sql("SELECT +MAX(key) FROM testData"), Row(100) ) checkAnswer( - sql("SELECT - (-10)"), 10 + sql("SELECT - (-10)"), Row(10) ) checkAnswer( - sql("SELECT + (-key) FROM testData WHERE key = 32"), -32 + sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32) ) checkAnswer( - sql("SELECT - (+Max(key)) FROM testData"), -100 + sql("SELECT - (+Max(key)) FROM testData"), Row(-100) ) checkAnswer( - sql("SELECT - - 3"), 3 + sql("SELECT - - 3"), Row(3) ) checkAnswer( - sql("SELECT - + 20"), -20 + sql("SELECT - + 20"), Row(-20) ) checkAnswer( - sql("SELEcT - + 45"), -45 + sql("SELEcT - + 45"), Row(-45) ) checkAnswer( - sql("SELECT + + 100"), 100 + sql("SELECT + + 100"), Row(100) ) checkAnswer( - sql("SELECT - - Max(key) FROM testData"), 100 + sql("SELECT - - Max(key) FROM testData"), Row(100) ) checkAnswer( - sql("SELECT + - key FROM testData WHERE key = 33"), -33 + sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33) ) } @@ -943,7 +947,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { |JOIN testData b ON a.key = b.key |JOIN testData c ON a.key = c.key """.stripMargin), - (1 to 100).map(i => Seq(i, i, i))) + (1 to 100).map(i => Row(i, i, i))) } test("SPARK-3483 Special chars in column names") { @@ -953,19 +957,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3814 Support Bitwise & operator") { - checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), 1) + checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise | operator") { - checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), 1) + checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ^ operator") { - checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), 1) + checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ~ operator") { - checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), -2) + checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), Row(-2)) } test("SPARK-4120 Join of multiple tables does not work in SparkSQL") { @@ -975,40 +979,40 @@ class SQLQuerySuite extends QueryTest with 
BeforeAndAfterAll { |FROM testData a,testData b,testData c |where a.key = b.key and a.key = c.key """.stripMargin), - (1 to 100).map(i => Seq(i, i, i))) + (1 to 100).map(i => Row(i, i, i))) } test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), - (11 to 100).map(i => Seq(i))) + (11 to 100).map(i => Row(i))) } test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), - (1 to 99).map(i => Seq(i))) + (1 to 99).map(i => Row(i))) } test("SPARK-4322 Grouping field with struct field as sub expression") { jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") - checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1) + checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) dropTempTable("data") jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") - checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2) + checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( sql("SELECT a + b FROM testData2 ORDER BY a"), - Seq(2, 3, 3 ,4 ,4 ,5).map(Seq(_)) + Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), - Seq((3, 1), (3, 2), (2, 1), (2,2), (1, 1), (1, 2)) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2)) ) } @@ -1021,13 +1025,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { rdd2.registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), - (1 to 2).map(i => Seq(i))) + (1 to 2).map(i => Row(i))) } test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.registerTempTable("distinctData") - checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), 2) + checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ee381da491054..a015884bae282 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -82,7 +82,7 @@ class ScalaReflectionRelationSuite extends FunSuite { rdd.registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === - Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, + Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) } @@ -91,7 +91,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectNullData") - assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null)) + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with 
Nones") { @@ -99,7 +99,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null)) + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 9be0b38e689ff..be2b34de077c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -42,8 +42,8 @@ class ColumnStatsSuite extends FunSuite { test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => - assert(actual === expected) + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + case (actual, expected) => assert(actual === expected) } } @@ -54,7 +54,7 @@ class ColumnStatsSuite extends FunSuite { val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType]) + val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] val stats = columnStats.collectedStatistics diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index d94729ba92360..e61f3c39631da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -49,7 +49,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key - }.toSeq) + }.map(Row.fromTuple)) } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { @@ -63,49 +63,49 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1678 regression: compression must not lose repeated values") { checkAnswer( sql("SELECT * FROM repeatedData"), - repeatedData.collect().toSeq) + repeatedData.collect().toSeq.map(Row.fromTuple)) cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), - repeatedData.collect().toSeq) + repeatedData.collect().toSeq.map(Row.fromTuple)) } test("with null values") { checkAnswer( sql("SELECT * FROM nullableRepeatedData"), - nullableRepeatedData.collect().toSeq) + nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), - nullableRepeatedData.collect().toSeq) + nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) } test("SPARK-2729 regression: timestamp data type") { checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq) + timestamps.collect().toSeq.map(Row.fromTuple)) cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq) + timestamps.collect().toSeq.map(Row.fromTuple)) } test("SPARK-3320 regression: batched column 
buffer building should work with empty partitions") { checkAnswer( sql("SELECT * FROM withEmptyParts"), - withEmptyParts.collect().toSeq) + withEmptyParts.collect().toSeq.map(Row.fromTuple)) cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), - withEmptyParts.collect().toSeq) + withEmptyParts.collect().toSeq.map(Row.fromTuple)) } test("SPARK-4182 Caching complex types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 592cafbbdc203..c3a3f8ddc3ebf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -108,7 +108,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val queryExecution = schemaRdd.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { - schemaRdd.collect().map(_.head).toArray + schemaRdd.collect().map(_(0)).toArray } val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index d9e488e0ffd16..8b518f094174c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -34,7 +34,7 @@ class BooleanBitSetSuite extends FunSuite { val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) - val values = rows.map(_.head) + val values = rows.map(_(0)) rows.foreach(builder.appendFrom(_, 0)) val buffer = builder.build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c5b6fce5fd297..67007b8c093ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ +import org.apache.spark.sql.types._ class PlannerSuite extends FunSuite { test("unions are collapsed") { @@ -60,19 +61,62 @@ class PlannerSuite extends FunSuite { } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { - val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) - - // Using a threshold that is definitely larger than the small testing table (b) below - val a = testData.as('a) - val b = testData.limit(3).as('b) - val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString) + val fields = fieldTypes.zipWithIndex.map { + case (dataType, index) => StructField(s"c${index}", dataType, true) + } :+ StructField("key", IntegerType, true) + val schema = 
StructType(fields) + val row = Row.fromSeq(Seq.fill(fields.size)(null)) + val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) + applySchema(rowRDD, schema).registerTempTable("testLimit") + + val planned = sql( + """ + |SELECT l.a, l.b + |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) + """.stripMargin).queryExecution.executedPlan + + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + + dropTempTable("testLimit") + } - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + val origThreshold = conf.autoBroadcastJoinThreshold - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + val simpleTypes = + NullType :: + BooleanType :: + ByteType :: + ShortType :: + IntegerType :: + LongType :: + FloatType :: + DoubleType :: + DecimalType(10, 5) :: + DecimalType.Unlimited :: + DateType :: + TimestampType :: + StringType :: + BinaryType :: Nil + + checkPlan(simpleTypes, newThreshold = 16434) + + val complexTypes = + ArrayType(DoubleType, true) :: + ArrayType(StringType, false) :: + MapType(IntegerType, StringType, true) :: + MapType(IntegerType, ArrayType(DoubleType), false) :: + StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", ArrayType(DoubleType), nullable = false), + StructField("c", DoubleType, nullable = false))) :: Nil + + checkPlan(complexTypes, newThreshold = 901617) setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index 2cab5e0c44d92..272c0d4cb2335 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -59,7 +59,7 @@ class TgfSuite extends QueryTest { checkAnswer( inputData.generate(ExampleTGF()), Seq( - "michael is 29 years old" :: Nil, - "Next year, michael will be 30 years old" :: Nil)) + Row("michael is 29 years old"), + Row("Next year, michael will be 30 years old"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 2bc9aede32f2a..94d14acccbb18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -229,13 +229,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - (new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.") :: Nil + Row(new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") ) } @@ -271,48 +271,49 @@ class JsonSuite extends QueryTest { // Access elements of a primitive array. 
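Stepping back to the PlannerSuite rework above: the new `checkPlan` helper registers a one-row temp table built from an explicit schema (the listed field types plus an integer `key`), lowers `SQLConf.AUTO_BROADCASTJOIN_THRESHOLD`, and asserts that a `BroadcastHashJoin` rather than a `ShuffledHashJoin` is planned. A sketch of just the schema and row construction it relies on, using an illustrative subset of the field types from the diff:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// One nullable column per type under test, plus the join key used by the query.
val fieldTypes: Seq[DataType] = Seq(IntegerType, StringType, ArrayType(DoubleType, true))
val fields = fieldTypes.zipWithIndex.map {
  case (dataType, index) => StructField(s"c$index", dataType, nullable = true)
} :+ StructField("key", IntegerType, nullable = true)
val schema = StructType(fields)

// A single all-null row is enough for the sizeInBytes estimate of the LIMIT 1 side.
val row = Row.fromSeq(Seq.fill(fields.size)(null))
```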
checkAnswer( sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), - ("str1", "str2", null) :: Nil + Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( sql("select arrayOfNull from jsonTable"), - Seq(Seq(null, null, null, null)) :: Nil + Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), - (new java.math.BigDecimal("922337203685477580700"), - new java.math.BigDecimal("-922337203685477580800"), null) :: Nil + Row(new java.math.BigDecimal("922337203685477580700"), + new java.math.BigDecimal("-922337203685477580800"), null) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), - (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil + Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), - (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil + Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). checkAnswer( sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), - ("str2", 2.1) :: Nil + Row("str2", 2.1) ) // Access elements of an array of structs. checkAnswer( sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + "from jsonTable"), - (true :: "str1" :: null :: Nil, - false :: null :: null :: Nil, - null :: null :: null :: Nil, - null) :: Nil + Row( + Row(true, "str1", null), + Row(false, null, null), + Row(null, null, null), + null) ) // Access a struct and fields inside of it. @@ -327,13 +328,13 @@ class JsonSuite extends QueryTest { // Access an array field of a struct. checkAnswer( sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), - (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil + Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. checkAnswer( sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), - (5, null) :: Nil + Row(5, null) ) } @@ -344,14 +345,14 @@ class JsonSuite extends QueryTest { // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), - (true, "str1") :: Nil + Row(true, "str1") ) // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. // Getting all values of a specific field from an array of structs. 
checkAnswer( sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), - (Seq(true, false), Seq("str1", null)) :: Nil + Row(Seq(true, false), Seq("str1", null)) ) } @@ -372,57 +373,57 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - ("true", 11L, null, 1.1, "13.1", "str1") :: - ("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - ("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - (null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("true", 11L, null, 1.1, "13.1", "str1") :: + Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: + Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: + Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. checkAnswer( sql("select num_bool - 10 from jsonTable where num_bool > 11"), - 2 + Row(2) ) // Widening to LongType checkAnswer( sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), - Seq(21474836370L) :: Seq(21474836470L) :: Nil + Row(21474836370L) :: Row(21474836470L) :: Nil ) checkAnswer( sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), - Seq(-89) :: Seq(21474836370L) :: Seq(21474836470L) :: Nil + Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Seq(new java.math.BigDecimal("21474836472.1")) :: Seq(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil ) // Widening to DoubleType checkAnswer( sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), - Seq(101.2) :: Seq(21474836471.2) :: Nil + Row(101.2) :: Row(21474836471.2) :: Nil ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - 92233720368547758071.2 + Row(92233720368547758071.2) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), - new java.math.BigDecimal("92233720368547758061.2").doubleValue + Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue) ) // String and Boolean conflict: resolve the type as string. checkAnswer( sql("select * from jsonTable where str_bool = 'str1'"), - ("true", 11L, null, 1.1, "13.1", "str1") :: Nil + Row("true", 11L, null, 1.1, "13.1", "str1") ) } @@ -434,24 +435,24 @@ class JsonSuite extends QueryTest { // Number and Boolean conflict: resolve the type as boolean in this query. checkAnswer( sql("select num_bool from jsonTable where NOT num_bool"), - false + Row(false) ) checkAnswer( sql("select str_bool from jsonTable where NOT str_bool"), - false + Row(false) ) // Right now, the analyzer does not know that num_bool should be treated as a boolean. // Number and Boolean conflict: resolve the type as boolean in this query. 
checkAnswer( sql("select num_bool from jsonTable where num_bool"), - true + Row(true) ) checkAnswer( sql("select str_bool from jsonTable where str_bool"), - false + Row(false) ) // The plan of the following DSL is @@ -464,7 +465,7 @@ class JsonSuite extends QueryTest { jsonSchemaRDD. where('num_str > BigDecimal("92233720368547758060")). select('num_str + 1.2 as Symbol("num")), - new java.math.BigDecimal("92233720368547758061.2") + Row(new java.math.BigDecimal("92233720368547758061.2")) ) // The following test will fail. The type of num_str is StringType. @@ -475,7 +476,7 @@ class JsonSuite extends QueryTest { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 13"), - Seq(14.3) :: Seq(92233720368547758071.2) :: Nil + Row(14.3) :: Row(92233720368547758071.2) :: Nil ) } @@ -496,10 +497,10 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - (Seq(), "11", "[1,2,3]", Seq(null), "[]") :: - (null, """{"field":false}""", null, null, "{}") :: - (Seq(4, 5, 6), null, "str", Seq(null), "[7,8,9]") :: - (Seq(7), "{}","[str1,str2,33]", Seq("str"), """{"field":true}""") :: Nil + Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: + Row(null, """{"field":false}""", null, null, "{}") :: + Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: + Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil ) } @@ -518,16 +519,16 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", - """{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) :: - Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: - Seq(null, null, Seq("1", "2", "3")) :: Nil + Row(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", + """{"field":"str"}"""), Seq(Row(214748364700L), Row(1)), null) :: + Row(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: + Row(null, null, Seq("1", "2", "3")) :: Nil ) // Treat an element as a number. 
checkAnswer( sql("select array1[0] + 1 from jsonTable where array1 is not null"), - 2 + Row(2) ) } @@ -568,13 +569,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) } @@ -594,13 +595,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTableSQL"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) } @@ -626,13 +627,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable1"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) @@ -643,13 +644,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable2"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) } @@ -659,7 +660,7 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), - (true, "str1") :: Nil + Row(true, "str1") ) checkAnswer( sql( @@ -667,7 +668,7 @@ class JsonSuite extends QueryTest { |select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] |from jsonTable """.stripMargin), - ("str2", 6) :: Nil + Row("str2", 6) ) } @@ -681,7 +682,7 @@ class JsonSuite extends QueryTest { |select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0] |from jsonTable """.stripMargin), - (5, 7, 8) :: Nil + Row(5, 7, 8) ) checkAnswer( sql( @@ -690,7 +691,7 @@ class JsonSuite extends QueryTest { |arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4 |from jsonTable """.stripMargin), - ("str1", Nil, "str4", 2) :: Nil + Row("str1", Nil, "str4", 2) ) } @@ -704,10 +705,10 @@ class JsonSuite extends QueryTest { |select a, b, c |from jsonTable """.stripMargin), - ("str_a_1", null, null) :: - ("str_a_2", null, null) :: - (null, "str_b_3", null) :: - ("str_a_4", "str_b_4", "str_c_4") :: Nil + Row("str_a_1", null, null) :: + Row("str_a_2", null, null) :: + Row(null, "str_b_3", null) :: + Row("str_a_4", "str_b_4", "str_c_4") :: Nil ) } @@ -734,12 +735,12 @@ class JsonSuite extends QueryTest { |SELECT a, b, c, _unparsed |FROM jsonTable """.stripMargin), - (null, null, null, "{") :: - (null, null, null, "") :: - (null, null, null, """{"a":1, b:2}""") :: - (null, null, null, """{"a":{, b:3}""") :: - ("str_a_4", "str_b_4", "str_c_4", null) :: - (null, null, null, "]") :: Nil + Row(null, null, null, "{") :: + Row(null, null, null, "") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil ) checkAnswer( @@ -749,7 +750,7 @@ class JsonSuite extends QueryTest { |FROM jsonTable |WHERE _unparsed IS NULL """.stripMargin), - ("str_a_4", "str_b_4", "str_c_4") :: Nil + 
Row("str_a_4", "str_b_4", "str_c_4") ) checkAnswer( @@ -759,11 +760,11 @@ class JsonSuite extends QueryTest { |FROM jsonTable |WHERE _unparsed IS NOT NULL """.stripMargin), - Seq("{") :: - Seq("") :: - Seq("""{"a":1, b:2}""") :: - Seq("""{"a":{, b:3}""") :: - Seq("]") :: Nil + Row("{") :: + Row("") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil ) TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) @@ -793,10 +794,10 @@ class JsonSuite extends QueryTest { |SELECT field1, field2, field3, field4 |FROM jsonTable """.stripMargin), - Seq(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) :: - Seq(null, Seq(null, Seq(Seq(1))), null, null) :: - Seq(null, null, Seq(Seq(null), Seq(Seq("2"))), null) :: - Seq(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil + Row(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) :: + Row(null, Seq(null, Seq(Row(1))), null, null) :: + Row(null, null, Seq(Seq(null), Seq(Row("2"))), null) :: + Row(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil ) } @@ -851,12 +852,12 @@ class JsonSuite extends QueryTest { primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, - "this is a simple string.") :: Nil + "this is a simple string.") ) val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1) @@ -865,38 +866,38 @@ class JsonSuite extends QueryTest { // Access elements of a primitive array. checkAnswer( sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), - ("str1", "str2", null) :: Nil + Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( sql("select arrayOfNull from complexTable"), - Seq(Seq(null, null, null, null)) :: Nil + Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"), - (new java.math.BigDecimal("922337203685477580700"), - new java.math.BigDecimal("-922337203685477580800"), null) :: Nil + Row(new java.math.BigDecimal("922337203685477580700"), + new java.math.BigDecimal("-922337203685477580800"), null) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), - (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil + Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), - (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil + Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). checkAnswer( sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), - ("str2", 2.1) :: Nil + Row("str2", 2.1) ) // Access a struct and fields inside of it. @@ -911,13 +912,13 @@ class JsonSuite extends QueryTest { // Access an array field of a struct. checkAnswer( sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), - (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil + Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. 
checkAnswer( sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), - (5, null) :: Nil + Row(5, null) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4c3a04506ce42..1e7d3e06fc196 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -21,7 +21,7 @@ import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Literal, Predicate, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} @@ -40,15 +40,16 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} class ParquetFilterSuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext - private def checkFilterPushdown( + private def checkFilterPredicate( rdd: SchemaRDD, - output: Seq[Symbol], predicate: Predicate, filterClass: Class[_ <: FilterPredicate], - checker: (SchemaRDD, Any) => Unit, - expectedResult: => Any): Unit = { + checker: (SchemaRDD, Seq[Row]) => Unit, + expected: Seq[Row]): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd.select(output.map(_.attr): _*).where(predicate) + val query = rdd.select(output: _*).where(predicate) val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { case plan: ParquetTableScan => plan.columnPruningPred @@ -58,202 +59,180 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { maybeAnalyzedPredicate.foreach { pred => val maybeFilter = ParquetFilters.createFilter(pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") - maybeFilter.foreach(f => assert(f.getClass === filterClass)) + maybeFilter.foreach { f => + // Doesn't bother checking type parameters here (e.g. 
`Eq[Integer]`) + assert(f.getClass === filterClass) + } } - checker(query, expectedResult) + checker(query, expected) } } - private def checkFilterPushdown - (rdd: SchemaRDD, output: Symbol*) - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: => Any): Unit = { - checkFilterPushdown(rdd, output, predicate, filterClass, checkAnswer _, expectedResult) + private def checkFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit rdd: SchemaRDD): Unit = { + checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) } - def checkBinaryFilterPushdown - (rdd: SchemaRDD, output: Symbol*) - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: => Any): Unit = { - def checkBinaryAnswer(rdd: SchemaRDD, result: Any): Unit = { - val actual = rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq - val expected = result match { - case s: Seq[_] => s.map(_.asInstanceOf[Row].getAs[Array[Byte]](0).mkString(",")) - case s => Seq(s.asInstanceOf[Array[Byte]].mkString(",")) - } - assert(actual.sorted === expected.sorted) - } - checkFilterPushdown(rdd, output, predicate, filterClass, checkBinaryAnswer _, expectedResult) + private def checkFilterPredicate[T] + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) + (implicit rdd: SchemaRDD): Unit = { + checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) { - Seq(Row(true), Row(false)) - } + withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - checkFilterPushdown(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(true) - checkFilterPushdown(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]]) { - false - } + checkFilterPredicate('_1 === true, classOf[Eq [_]], true) + checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } } test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[Integer]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [Integer]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [Integer]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[Integer]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[Integer]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [Integer]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[Integer]])(4) - - 
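With the refactor above, each call site shrinks to checkFilterPredicate(predicate, filterClass, expected): the SchemaRDD arrives through an implicit parameter and the projected attributes are collected from the predicate itself. A standalone sketch of that implicit-threading pattern, with purely illustrative names (Frame, check, and withFrame are not from the patch):

    object ImplicitParamSketch extends App {
      case class Frame(name: String)

      // The helper no longer takes the fixture explicitly; it resolves it from implicit scope.
      def check(expected: Int)(implicit frame: Frame): Unit =
        println(s"checking $expected against ${frame.name}")

      def withFrame(name: String)(body: Frame => Unit): Unit = body(Frame(name))

      withFrame("t") { implicit frame =>  // marking the block parameter implicit...
        check(1)                          // ...lets every call inside pick it up automatically
        check(2)
      }
    }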
checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[Integer]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Long]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Long]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Long]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Long]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Long]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, 
classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Float]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Float]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Float]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Float]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Float]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) { - (2 to 
4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Double]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Double]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq[Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Double]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Double]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Double]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { - (1 to 4).map(i => Row.apply(i.toString)) - } - - checkFilterPushdown(rdd, '_1)('_1 === "1", classOf[Eq[String]])("1") - checkFilterPushdown(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) { - (2 to 4).map(i => Row.apply(i.toString)) - } + withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1") + checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 < "2", classOf[Lt [_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4") + checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") + checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") + + checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4") + checkFilterPredicate(Literal("1") >= '_1, 
classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + } + } - checkFilterPushdown(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])("4") - checkFilterPushdown(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])("4") - - checkFilterPushdown(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])("4") - checkFilterPushdown(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])("4") - - checkFilterPushdown(rdd, '_1)(!('_1 < "4"), classOf[Operators.GtEq[java.lang.String]])("4") - checkFilterPushdown(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])("3") - checkFilterPushdown(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) { - Seq(Row("1"), Row("4")) + def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit rdd: SchemaRDD): Unit = { + def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } + + checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) + } + + def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) + (implicit rdd: SchemaRDD): Unit = { + checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } test("filter pushdown - binary") { @@ -261,33 +240,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { def b: Array[Byte] = int.toString.getBytes("UTF-8") } - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { rdd => - checkBinaryFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) - checkBinaryFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { - (1 to 4).map(i => Row.apply(i.b)).toSeq - } - - checkBinaryFilterPushdown(rdd, '_1)('_1 === 1.b, classOf[Eq[Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 !== 1.b, classOf[Operators.NotEq[Array[Byte]]]) { - (2 to 4).map(i => Row.apply(i.b)).toSeq - } - - checkBinaryFilterPushdown(rdd, '_1)('_1 < 2.b, classOf[Lt [Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 > 3.b, classOf[Gt [Array[Byte]]])(4.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 <= 1.b, classOf[LtEq[Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 >= 4.b, classOf[GtEq[Array[Byte]]])(4.b) - - checkBinaryFilterPushdown(rdd, '_1)(Literal(1.b) === '_1, classOf[Eq [Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(2.b) > '_1, classOf[Lt [Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(3.b) < '_1, classOf[Gt [Array[Byte]]])(4.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(1.b) >= '_1, classOf[LtEq[Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(4.b) <= '_1, 
classOf[GtEq[Array[Byte]]])(4.b) - - checkBinaryFilterPushdown(rdd, '_1)(!('_1 < 4.b), classOf[Operators.GtEq[Array[Byte]]])(4.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 > 2.b && '_1 < 4.b, classOf[Operators.And])(3.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 < 2.b || '_1 > 3.b, classOf[Operators.Or]) { - Seq(Row(1.b), Row(4.b)) - } + withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkBinaryFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq [_]], 1.b) + checkBinaryFilterPredicate( + '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) + checkBinaryFilterPredicate( + '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 973819aaa4d77..a57e4e85a35ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -68,8 +68,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { /** * Writes `data` to a Parquet file, reads it back and check file contents. 
*/ - protected def checkParquetFile[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetRDD(data)(checkAnswer(_, data)) + protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { + withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } test("basic data types (without binary)") { @@ -143,7 +143,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data) { rdd => // Structs are converted to `Row`s checkAnswer(rdd, data.map { case Tuple1(struct) => - Tuple1(Row(struct.productIterator.toSeq: _*)) + Row(Row(struct.productIterator.toSeq: _*)) }) } } @@ -153,7 +153,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data) { rdd => // Structs are converted to `Row`s checkAnswer(rdd, data.map { case Tuple1(struct) => - Tuple1(Row(struct.productIterator.toSeq: _*)) + Row(Row(struct.productIterator.toSeq: _*)) }) } } @@ -162,7 +162,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) withParquetRDD(data) { rdd => checkAnswer(rdd, data.map { case Tuple1(m) => - Tuple1(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) + Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) }) } } @@ -261,7 +261,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) checkAnswer(parquetFile(path.toString), (0 until 10).map { i => - (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3a073a6b7057e..1263ff818ea19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,1030 +17,72 @@ package org.apache.spark.sql.parquet -import scala.reflect.ClassTag - -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} -import parquet.filter2.predicate.{FilterPredicate, Operators} -import parquet.hadoop.ParquetFileWriter -import parquet.hadoop.util.ContextUtil -import parquet.io.api.Binary - -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -case class TestRDDEntry(key: Int, value: String) - -case class NullReflectData( - intField: java.lang.Integer, - longField: java.lang.Long, - floatField: java.lang.Float, - doubleField: java.lang.Double, - booleanField: java.lang.Boolean) - -case class OptionalReflectData( - intField: Option[Int], - longField: Option[Long], - floatField: Option[Float], - doubleField: Option[Double], - booleanField: Option[Boolean]) - -case class Nested(i: Int, s: String) - -case class Data(array: Seq[Int], nested: Nested) - -case class AllDataTypes( - stringField: String, - intField: Int, - longField: Long, - floatField: Float, - doubleField: Double, - shortField: Short, - byteField: Byte, - booleanField: Boolean) - -case class 
AllDataTypesWithNonPrimitiveType( - stringField: String, - intField: Int, - longField: Long, - floatField: Float, - doubleField: Double, - shortField: Short, - byteField: Byte, - booleanField: Boolean, - array: Seq[Int], - arrayContainsNull: Seq[Option[Int]], - map: Map[Int, Long], - mapValueContainsNull: Map[Int, Option[Long]], - data: Data) - -case class BinaryData(binaryData: Array[Byte]) - -case class NumericData(i: Int, d: Double) -class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { - TestData // Load test data tables. - - private var testRDD: SchemaRDD = null - private val originalParquetFilterPushdownEnabled = TestSQLContext.conf.parquetFilterPushDown - - override def beforeAll() { - ParquetTestData.writeFile() - ParquetTestData.writeFilterFile() - ParquetTestData.writeNestedFile1() - ParquetTestData.writeNestedFile2() - ParquetTestData.writeNestedFile3() - ParquetTestData.writeNestedFile4() - ParquetTestData.writeGlobFiles() - testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerTempTable("testsource") - parquetFile(ParquetTestData.testFilterDir.toString) - .registerTempTable("testfiltersource") - - setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, "true") - } - - override def afterAll() { - Utils.deleteRecursively(ParquetTestData.testDir) - Utils.deleteRecursively(ParquetTestData.testFilterDir) - Utils.deleteRecursively(ParquetTestData.testNestedDir1) - Utils.deleteRecursively(ParquetTestData.testNestedDir2) - Utils.deleteRecursively(ParquetTestData.testNestedDir3) - Utils.deleteRecursively(ParquetTestData.testNestedDir4) - Utils.deleteRecursively(ParquetTestData.testGlobDir) - // here we should also unregister the table?? - - setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, originalParquetFilterPushdownEnabled.toString) - } +/** + * A test suite that tests various Parquet queries. + */ +class ParquetQuerySuite extends QueryTest with ParquetTest { + val sqlContext = TestSQLContext - test("Read/Write All Types") { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range).map { x => - parquet.AllDataTypes( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0) + test("simple projection") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) } - - data.saveAsParquetFile(tempDir) - - checkAnswer( - parquetFile(tempDir), - data.toSchemaRDD.collect().toSeq) } - test("read/write binary data") { - // Since equality for Array[Byte] is broken we test this separately. - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir) - parquetFile(tempDir) - .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) - .collect().toSeq == Seq("test") - } - - ignore("Treat binary as string") { - val oldIsParquetBinaryAsString = TestSQLContext.conf.isParquetBinaryAsString - - // Create the test file. - val file = getTempFilePath("parquet") - val path = file.toString - val range = (0 to 255) - val rowRDD = TestSQLContext.sparkContext.parallelize(range) - .map(i => org.apache.spark.sql.Row(i, s"val_$i".getBytes)) - // We need to ask Parquet to store the String column as a Binary column. 
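The rewritten ParquetQuerySuite builds its fixtures inline through helpers such as withParquetTable and withParquetRDD rather than the files prepared in beforeAll/afterAll. A rough loan-pattern sketch of the shape such a helper typically takes; the body below is hypothetical and only stands in for the real ParquetTest implementation:

    object LoanPatternSketch extends App {
      def withTempTable(name: String)(body: => Unit): Unit = {
        println(s"register temp table $name")        // setup stand-in: write data, register the table
        try body                                     // run the test body against the fixture
        finally println(s"drop temp table $name")    // teardown runs even if the body throws
      }

      withTempTable("t") {
        println("SELECT _1 FROM t")                  // stand-in for the query under test
      }
    }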
- val schema = StructType( - StructField("c1", IntegerType, false) :: - StructField("c2", BinaryType, false) :: Nil) - val schemaRDD1 = applySchema(rowRDD, schema) - schemaRDD1.saveAsParquetFile(path) - checkAnswer( - parquetFile(path).select('c1, 'c2.cast(StringType)), - schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - - setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") - parquetFile(path).printSchema() - checkAnswer( - parquetFile(path), - schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - - - // Set it back. - TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) - } - - test("Compression options for writing to a Parquetfile") { - val defaultParquetCompressionCodec = TestSQLContext.conf.parquetCompressionCodec - import scala.collection.JavaConversions._ - - val file = getTempFilePath("parquet") - val path = file.toString - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - - // test default compression codec - rdd.saveAsParquetFile(path) - var actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test uncompressed parquet file with property value "UNCOMPRESSED" - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "UNCOMPRESSED") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test uncompressed parquet file with property value "none" - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "none") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === "UNCOMPRESSED" :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test gzip compression codec - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "gzip") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) - 
- Utils.deleteRecursively(file) - - // test snappy compression codec - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "snappy") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // TODO: Lzo requires additional external setup steps so leave it out for now - // ref.: https://github.com/Parquet/parquet-mr/blob/parquet-1.5.0/parquet-hadoop/src/test/java/parquet/hadoop/example/TestInputOutputFormat.java#L169 - - // Set it back. - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, defaultParquetCompressionCodec) - } - - test("Read/Write All Types with non-primitive type") { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range).map { x => - parquet.AllDataTypesWithNonPrimitiveType( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 until x), - (0 until x).map(Option(_).filter(_ % 3 == 0)), - (0 until x).map(i => i -> i.toLong).toMap, - (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), - parquet.Data((0 until x), parquet.Nested(x, s"$x"))) + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT INTO t SELECT * FROM t") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } - data.saveAsParquetFile(tempDir) - - checkAnswer( - parquetFile(tempDir), - data.toSchemaRDD.collect().toSeq) } - test("self-join parquet files") { - val x = ParquetTestData.testData.as('x) - val y = ParquetTestData.testData.as('y) - val query = x.join(y).where("x.myint".attr === "y.myint".attr) - - // Check to make sure that the attributes from either side of the join have unique expression - // ids. 
- query.queryExecution.analyzed.output.filter(_.name == "myint") match { - case Seq(i1, i2) if(i1.exprId == i2.exprId) => - fail(s"Duplicate expression IDs found in query plan: $query") - case Seq(_, _) => // All good + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) } - val result = query.collect() - assert(result.size === 9, "self-join result has incorrect size") - assert(result(0).size === 12, "result row has incorrect size") - result.zipWithIndex.foreach { - case (row, index) => row.zipWithIndex.foreach { - case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column") - } - } - } + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output - test("Import of simple Parquet file") { - val result = parquetFile(ParquetTestData.testDir.toString).collect() - assert(result.size === 15) - result.zipWithIndex.foreach { - case (row, index) => { - val checkBoolean = - if (index % 3 == 0) - row(0) == true - else - row(0) == false - assert(checkBoolean === true, s"boolean field value in line $index did not match") - if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") - assert(row(2) === "abc", s"string field value in line $index did not match") - assert(row(3) === (index.toLong << 33), s"long value in line $index did not match") - assert(row(4) === 2.5F, s"float field value in line $index did not match") - assert(row(5) === 4.5D, s"double field value in line $index did not match") + assertResult(4, s"Field count mismatches")(queryOutput.size) + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size } - } - } - test("Projection of simple Parquet file") { - val result = ParquetTestData.testData.select('myboolean, 'mylong).collect() - result.zipWithIndex.foreach { - case (row, index) => { - if (index % 3 == 0) - assert(row(0) === true, s"boolean field value in line $index did not match (every third row)") - else - assert(row(0) === false, s"boolean field value in line $index did not match") - assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match") - assert(row.size === 2, s"number of columns in projection in line $index is incorrect") - } + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) } } - test("Writing metadata from scratch for table CREATE") { - val job = new Job() - val path = new Path(getTempFilePath("testtable").getCanonicalFile.toURI.toString) - val fs: FileSystem = FileSystem.getLocal(ContextUtil.getConfiguration(job)) - ParquetTypesConverter.writeMetaData( - ParquetTestData.testData.output, - path, - TestSQLContext.sparkContext.hadoopConfiguration) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(ContextUtil.getConfiguration(job))) - assert(metaData != null) - ParquetTestData - .testData - .parquetSchema - .checkContains(metaData.getFileMetaData.getSchema) // throws exception if incompatible - metaData - .getFileMetaData - .getSchema - .checkContains(ParquetTestData.testData.parquetSchema) // throws exception if incompatible - fs.delete(path, true) - } - - test("Creating case class RDD table") { - TestSQLContext.sparkContext.parallelize((1 to 100)) - 
.map(i => TestRDDEntry(i, s"val_$i")) - .registerTempTable("tmp") - val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) - var counter = 1 - rdd.foreach { - // '===' does not like string comparison? - row: Row => { - assert(row.getString(1).equals(s"val_$counter"), s"row $counter value ${row.getString(1)} does not match val_$counter") - counter = counter + 1 - } + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) } } - test("Read a parquet file instead of a directory") { - val file = getTempFilePath("parquet") - val path = file.toString - val fsPath = new Path(path) - val fs: FileSystem = fsPath.getFileSystem(TestSQLContext.sparkContext.hadoopConfiguration) - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.coalesce(1).saveAsParquetFile(path) - - val children = fs.listStatus(fsPath).filter(_.getPath.getName.endsWith(".parquet")) - assert(children.length > 0) - val readFile = parquetFile(path + "/" + children(0).getPath.getName) - readFile.registerTempTable("tmpx") - val rdd_copy = sql("SELECT * FROM tmpx").collect() - val rdd_orig = rdd.collect() - for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) } - Utils.deleteRecursively(file) - } - - test("Insert (appending) to same table via Scala API") { - sql("INSERT INTO testsource SELECT * FROM testsource") - val double_rdd = sql("SELECT * FROM testsource").collect() - assert(double_rdd != null) - assert(double_rdd.size === 30) - - // let's restore the original test data - Utils.deleteRecursively(ParquetTestData.testDir) - ParquetTestData.writeFile() - } - - test("save and load case class RDD with nulls as parquet") { - val data = parquet.NullReflectData(null, null, null, null, null) - val rdd = sparkContext.parallelize(data :: Nil) - - val file = getTempFilePath("parquet") - val path = file.toString - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - - val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) - Utils.deleteRecursively(file) - assert(true) - } - - test("save and load case class RDD with Nones as parquet") { - val data = parquet.OptionalReflectData(None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - - val file = getTempFilePath("parquet") - val path = file.toString - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - - val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) - Utils.deleteRecursively(file) - assert(true) - } - - test("make RecordFilter for simple predicates") { - def checkFilter[T <: FilterPredicate : ClassTag]( - predicate: Expression, - defined: Boolean = true): Unit = { - val filter = ParquetFilters.createFilter(predicate) - if (defined) { - assert(filter.isDefined) - val tClass = implicitly[ClassTag[T]].runtimeClass - val filterGet = filter.get - assert( - tClass.isInstance(filterGet), - s"$filterGet of type ${filterGet.getClass} is not an instance of $tClass") - } else { - 
assert(filter.isEmpty) - } - } - - checkFilter[Operators.Eq[Integer]]('a.int === 1) - checkFilter[Operators.Eq[Integer]](Literal(1) === 'a.int) - - checkFilter[Operators.Lt[Integer]]('a.int < 4) - checkFilter[Operators.Lt[Integer]](Literal(4) > 'a.int) - checkFilter[Operators.LtEq[Integer]]('a.int <= 4) - checkFilter[Operators.LtEq[Integer]](Literal(4) >= 'a.int) - - checkFilter[Operators.Gt[Integer]]('a.int > 4) - checkFilter[Operators.Gt[Integer]](Literal(4) < 'a.int) - checkFilter[Operators.GtEq[Integer]]('a.int >= 4) - checkFilter[Operators.GtEq[Integer]](Literal(4) <= 'a.int) - - checkFilter[Operators.And]('a.int === 1 && 'a.int < 4) - checkFilter[Operators.Or]('a.int === 1 || 'a.int < 4) - checkFilter[Operators.NotEq[Integer]](!('a.int === 1)) - - checkFilter('a.int > 'b.int, defined = false) - checkFilter(('a.int > 'b.int) && ('a.int > 'b.int), defined = false) - } - - test("test filter by predicate pushdown") { - for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) { - val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") - assert( - query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result1 = query1.collect() - assert(result1.size === 50) - assert(result1(0)(1) === 100) - assert(result1(49)(1) === 149) - val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") - assert( - query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result2 = query2.collect() - assert(result2.size === 50) - if (myval == "myint" || myval == "mylong") { - assert(result2(0)(1) === 151) - assert(result2(49)(1) === 200) - } else { - assert(result2(0)(1) === 150) - assert(result2(49)(1) === 199) - } - } - for(myval <- Seq("myint", "mylong")) { - val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10") - assert( - query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result3 = query3.collect() - assert(result3.size === 20) - assert(result3(0)(1) === 0) - assert(result3(9)(1) === 9) - assert(result3(10)(1) === 191) - assert(result3(19)(1) === 200) - } - for(myval <- Seq("mydouble", "myfloat")) { - val result4 = - if (myval == "mydouble") { - val query4 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10.0") - assert( - query4.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - query4.collect() - } else { - // CASTs are problematic. Here myfloat will be casted to a double and it seems there is - // currently no way to specify float constants in SqlParser? 
- sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect() - } - assert(result4.size === 20) - assert(result4(0)(1) === 0) - assert(result4(9)(1) === 9) - assert(result4(10)(1) === 191) - assert(result4(19)(1) === 200) - } - val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40") - assert( - query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val booleanResult = query5.collect() - assert(booleanResult.size === 10) - for(i <- 0 until 10) { - if (!booleanResult(i).getBoolean(0)) { - fail(s"Boolean value in result row $i not true") - } - if (booleanResult(i).getInt(1) != i * 4) { - fail(s"Int value in result row $i should be ${4*i}") - } - } - val query6 = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"") - assert( - query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val stringResult = query6.collect() - assert(stringResult.size === 1) - assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") - assert(stringResult(0).getInt(1) === 100) - - val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") - assert( - query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val optResult = query7.collect() - assert(optResult.size === 20) - for(i <- 0 until 20) { - if (optResult(i)(7) != i * 2) { - fail(s"optional Int value in result row $i should be ${2*4*i}") - } - } - for(myval <- Seq("myoptint", "myoptlong", "myoptdouble", "myoptfloat")) { - val query8 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") - assert( - query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result8 = query8.collect() - assert(result8.size === 25) - assert(result8(0)(7) === 100) - assert(result8(24)(7) === 148) - val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") - assert( - query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result9 = query9.collect() - assert(result9.size === 25) - if (myval == "myoptint" || myval == "myoptlong") { - assert(result9(0)(7) === 152) - assert(result9(24)(7) === 200) - } else { - assert(result9(0)(7) === 150) - assert(result9(24)(7) === 198) - } - } - val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"") - assert( - query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result10 = query10.collect() - assert(result10.size === 1) - assert(result10(0).getString(8) == "100", "stringvalue incorrect") - assert(result10(0).getInt(7) === 100) - val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40") - assert( - query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result11 = query11.collect() - assert(result11.size === 7) - for(i <- 0 until 6) { - if (!result11(i).getBoolean(6)) { - fail(s"optional Boolean value in result row $i not true") - } - if (result11(i).getInt(7) != i * 6) { - fail(s"optional Int value in result row $i should be ${6*i}") - } - } - - val query12 = sql("SELECT * FROM 
testfiltersource WHERE mystring >= \"50\"") - assert( - query12.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result12 = query12.collect() - assert(result12.size === 54) - assert(result12(0).getString(2) == "6") - assert(result12(4).getString(2) == "50") - assert(result12(53).getString(2) == "99") - - val query13 = sql("SELECT * FROM testfiltersource WHERE mystring > \"50\"") - assert( - query13.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result13 = query13.collect() - assert(result13.size === 53) - assert(result13(0).getString(2) == "6") - assert(result13(4).getString(2) == "51") - assert(result13(52).getString(2) == "99") - - val query14 = sql("SELECT * FROM testfiltersource WHERE mystring <= \"50\"") - assert( - query14.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result14 = query14.collect() - assert(result14.size === 148) - assert(result14(0).getString(2) == "0") - assert(result14(46).getString(2) == "50") - assert(result14(147).getString(2) == "200") - - val query15 = sql("SELECT * FROM testfiltersource WHERE mystring < \"50\"") - assert( - query15.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result15 = query15.collect() - assert(result15.size === 147) - assert(result15(0).getString(2) == "0") - assert(result15(46).getString(2) == "100") - assert(result15(146).getString(2) == "200") } test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") - assert(query.collect().size === 10) - } - - test("Importing nested Parquet file (Addressbook)") { - val result = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD - .collect() - assert(result != null) - assert(result.size === 2) - val first_record = result(0) - val second_record = result(1) - assert(first_record != null) - assert(second_record != null) - assert(first_record.size === 3) - assert(second_record(1) === null) - assert(second_record(2) === null) - assert(second_record(0) === "A. 
Nonymous") - assert(first_record(0) === "Julien Le Dem") - val first_owner_numbers = first_record(1) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - val first_contacts = first_record(2) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(first_owner_numbers != null) - assert(first_owner_numbers(0) === "555 123 4567") - assert(first_owner_numbers(2) === "XXX XXX XXXX") - assert(first_contacts(0) - .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2) - val first_contacts_entry_one = first_contacts(0) - .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy") - assert(first_contacts_entry_one(1) === "555 987 6543") - val first_contacts_entry_two = first_contacts(1) - .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_two(0) === "Chris Aniszczyk") - } - - test("Importing nested Parquet file (nested numbers)") { - val result = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir2.toString) - .toSchemaRDD - .collect() - assert(result.size === 1, "number of top-level rows incorrect") - assert(result(0).size === 5, "number of fields in row incorrect") - assert(result(0)(0) === 1) - assert(result(0)(1) === 7) - val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult1.size === 3) - assert(subresult1(0) === (1.toLong << 32)) - assert(subresult1(1) === (1.toLong << 33)) - assert(subresult1(2) === (1.toLong << 34)) - val subresult2 = result(0)(3) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult2.size === 2) - assert(subresult2(0) === 2.5) - assert(subresult2(1) === false) - val subresult3 = result(0)(4) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult3.size === 2) - assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2) - val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) - } - - test("Simple query on addressbook") { - val data = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD - val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect() - assert(tmp.size === 1) - assert(tmp(0)(0) === "Julien Le Dem") - } - - test("Projection in addressbook") { - val data = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD - data.registerTempTable("data") - val query = sql("SELECT owner, contacts[1].name FROM data") - val tmp = query.collect() - assert(tmp.size === 2) - assert(tmp(0).size === 2) - assert(tmp(0)(0) === "Julien Le Dem") - assert(tmp(0)(1) === "Chris Aniszczyk") - assert(tmp(1)(0) === "A. 
Nonymous") - assert(tmp(1)(1) === null) - } - - test("Simple query on nested int data") { - val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD - data.registerTempTable("data") - val result1 = sql("SELECT entries[0].value FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === 2.5) - val result2 = sql("SELECT entries[0] FROM data").collect() - assert(result2.size === 1) - val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult1.size === 2) - assert(subresult1(0) === 2.5) - assert(subresult1(1) === false) - val result3 = sql("SELECT outerouter FROM data").collect() - val subresult2 = result3(0)(0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(result3(0)(0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) - } - - test("nested structs") { - val data = parquetFile(ParquetTestData.testNestedDir3.toString) - .toSchemaRDD - data.registerTempTable("data") - val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === false) - val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() - assert(result2.size === 1) - assert(result2(0).size === 1) - assert(result2(0)(0) === true) - val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() - assert(result3.size === 1) - assert(result3(0).size === 1) - assert(result3(0)(0) === false) - } - - test("simple map") { - val data = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir4.toString) - .toSchemaRDD - data.registerTempTable("mapTable") - val result1 = sql("SELECT data1 FROM mapTable").collect() - assert(result1.size === 1) - assert(result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key1", 0) === 1) - assert(result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key2", 0) === 2) - val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect() - assert(result2(0)(0) === 1) - } - - test("map with struct values") { - val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD - data.registerTempTable("mapTable") - val result1 = sql("SELECT data2 FROM mapTable").collect() - assert(result1.size === 1) - val entry1 = result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("seven", null) - assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") - val entry2 = result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("eight", null) - assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) - val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() - assert(result2.size === 1) - assert(result2(0)(0) === 42.toLong) - assert(result2(0)(1) === "the answer") - } - - test("Writing out Addressbook and reading it back in") { - // TODO: find out why CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME - // has no 
effect in this test case - val tmpdir = Utils.createTempDir() - Utils.deleteRecursively(tmpdir) - val result = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD - result.saveAsParquetFile(tmpdir.toString) - parquetFile(tmpdir.toString) - .toSchemaRDD - .registerTempTable("tmpcopy") - val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() - assert(tmpdata.size === 2) - assert(tmpdata(0).size === 2) - assert(tmpdata(0)(0) === "Julien Le Dem") - assert(tmpdata(0)(1) === "Chris Aniszczyk") - assert(tmpdata(1)(0) === "A. Nonymous") - assert(tmpdata(1)(1) === null) - Utils.deleteRecursively(tmpdir) - } - - test("Writing out Map and reading it back in") { - val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD - val tmpdir = Utils.createTempDir() - Utils.deleteRecursively(tmpdir) - data.saveAsParquetFile(tmpdir.toString) - parquetFile(tmpdir.toString) - .toSchemaRDD - .registerTempTable("tmpmapcopy") - val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() - assert(result1.size === 1) - assert(result1(0)(0) === 2) - val result2 = sql("SELECT data2 FROM tmpmapcopy").collect() - assert(result2.size === 1) - val entry1 = result2(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("seven", null) - assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") - val entry2 = result2(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("eight", null) - assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) - val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() - assert(result3.size === 1) - assert(result3(0)(0) === 42.toLong) - assert(result3(0)(1) === "the answer") - Utils.deleteRecursively(tmpdir) - } - - test("read/write fixed-length decimals") { - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val data = sparkContext.parallelize(0 to 1000) - .map(i => NumericData(i, i / 100.0)) - .select('i, 'd cast DecimalType(precision, scale)) - data.saveAsParquetFile(tempDir) - checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) - } - - // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val data = sparkContext.parallelize(0 to 1000) - .map(i => NumericData(i, i / 100.0)) - .select('i, 'd cast DecimalType(19, 10)) - data.saveAsParquetFile(tempDir) - checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) - } - - // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val data = sparkContext.parallelize(0 to 1000) - .map(i => NumericData(i, i / 100.0)) - .select('i, 'd cast DecimalType.Unlimited) - data.saveAsParquetFile(tempDir) - checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) - } - } - - def checkFilter(predicate: Predicate, filterClass: Class[_ <: FilterPredicate]): Unit = { - val filter = ParquetFilters.createFilter(predicate) - assert(filter.isDefined) - assert(filter.get.getClass == filterClass) - } - - test("Pushdown IsNull predicate") { - checkFilter('a.int.isNull, classOf[Operators.Eq[Integer]]) - checkFilter('a.long.isNull, classOf[Operators.Eq[java.lang.Long]]) - 
checkFilter('a.float.isNull, classOf[Operators.Eq[java.lang.Float]]) - checkFilter('a.double.isNull, classOf[Operators.Eq[java.lang.Double]]) - checkFilter('a.string.isNull, classOf[Operators.Eq[Binary]]) - checkFilter('a.binary.isNull, classOf[Operators.Eq[Binary]]) - } - - test("Pushdown IsNotNull predicate") { - checkFilter('a.int.isNotNull, classOf[Operators.NotEq[Integer]]) - checkFilter('a.long.isNotNull, classOf[Operators.NotEq[java.lang.Long]]) - checkFilter('a.float.isNotNull, classOf[Operators.NotEq[java.lang.Float]]) - checkFilter('a.double.isNotNull, classOf[Operators.NotEq[java.lang.Double]]) - checkFilter('a.string.isNotNull, classOf[Operators.NotEq[Binary]]) - checkFilter('a.binary.isNotNull, classOf[Operators.NotEq[Binary]]) - } - - test("Pushdown EqualTo predicate") { - checkFilter('a.int === 0, classOf[Operators.Eq[Integer]]) - checkFilter('a.long === 0.toLong, classOf[Operators.Eq[java.lang.Long]]) - checkFilter('a.float === 0.toFloat, classOf[Operators.Eq[java.lang.Float]]) - checkFilter('a.double === 0.toDouble, classOf[Operators.Eq[java.lang.Double]]) - checkFilter('a.string === "foo", classOf[Operators.Eq[Binary]]) - checkFilter('a.binary === "foo".getBytes, classOf[Operators.Eq[Binary]]) - } - - test("Pushdown Not(EqualTo) predicate") { - checkFilter(!('a.int === 0), classOf[Operators.NotEq[Integer]]) - checkFilter(!('a.long === 0.toLong), classOf[Operators.NotEq[java.lang.Long]]) - checkFilter(!('a.float === 0.toFloat), classOf[Operators.NotEq[java.lang.Float]]) - checkFilter(!('a.double === 0.toDouble), classOf[Operators.NotEq[java.lang.Double]]) - checkFilter(!('a.string === "foo"), classOf[Operators.NotEq[Binary]]) - checkFilter(!('a.binary === "foo".getBytes), classOf[Operators.NotEq[Binary]]) - } - - test("Pushdown LessThan predicate") { - checkFilter('a.int < 0, classOf[Operators.Lt[Integer]]) - checkFilter('a.long < 0.toLong, classOf[Operators.Lt[java.lang.Long]]) - checkFilter('a.float < 0.toFloat, classOf[Operators.Lt[java.lang.Float]]) - checkFilter('a.double < 0.toDouble, classOf[Operators.Lt[java.lang.Double]]) - checkFilter('a.string < "foo", classOf[Operators.Lt[Binary]]) - checkFilter('a.binary < "foo".getBytes, classOf[Operators.Lt[Binary]]) - } - - test("Pushdown LessThanOrEqual predicate") { - checkFilter('a.int <= 0, classOf[Operators.LtEq[Integer]]) - checkFilter('a.long <= 0.toLong, classOf[Operators.LtEq[java.lang.Long]]) - checkFilter('a.float <= 0.toFloat, classOf[Operators.LtEq[java.lang.Float]]) - checkFilter('a.double <= 0.toDouble, classOf[Operators.LtEq[java.lang.Double]]) - checkFilter('a.string <= "foo", classOf[Operators.LtEq[Binary]]) - checkFilter('a.binary <= "foo".getBytes, classOf[Operators.LtEq[Binary]]) - } - - test("Pushdown GreaterThan predicate") { - checkFilter('a.int > 0, classOf[Operators.Gt[Integer]]) - checkFilter('a.long > 0.toLong, classOf[Operators.Gt[java.lang.Long]]) - checkFilter('a.float > 0.toFloat, classOf[Operators.Gt[java.lang.Float]]) - checkFilter('a.double > 0.toDouble, classOf[Operators.Gt[java.lang.Double]]) - checkFilter('a.string > "foo", classOf[Operators.Gt[Binary]]) - checkFilter('a.binary > "foo".getBytes, classOf[Operators.Gt[Binary]]) - } - - test("Pushdown GreaterThanOrEqual predicate") { - checkFilter('a.int >= 0, classOf[Operators.GtEq[Integer]]) - checkFilter('a.long >= 0.toLong, classOf[Operators.GtEq[java.lang.Long]]) - checkFilter('a.float >= 0.toFloat, classOf[Operators.GtEq[java.lang.Float]]) - checkFilter('a.double >= 0.toDouble, classOf[Operators.GtEq[java.lang.Double]]) - 
checkFilter('a.string >= "foo", classOf[Operators.GtEq[Binary]]) - checkFilter('a.binary >= "foo".getBytes, classOf[Operators.GtEq[Binary]]) - } - - test("Comparison with null should not be pushed down") { - val predicates = Seq( - 'a.int === null, - !('a.int === null), - - Literal(null) === 'a.int, - !(Literal(null) === 'a.int), - - 'a.int < null, - 'a.int <= null, - 'a.int > null, - 'a.int >= null, - - Literal(null) < 'a.int, - Literal(null) <= 'a.int, - Literal(null) > 'a.int, - Literal(null) >= 'a.int - ) - - predicates.foreach { p => - assert( - ParquetFilters.createFilter(p).isEmpty, - "Comparison predicate with null shouldn't be pushed down") - } - } - - test("Import of simple Parquet files using glob wildcard pattern") { - val testGlobDir = ParquetTestData.testGlobDir.toString - val globPatterns = Array(testGlobDir + "/*/*", testGlobDir + "/spark-*/*", testGlobDir + "/?pa?k-*/*") - globPatterns.foreach { path => - val result = parquetFile(path).collect() - assert(result.size === 45) - result.zipWithIndex.foreach { - case (row, index) => { - val checkBoolean = - if ((index % 15) % 3 == 0) - row(0) == true - else - row(0) == false - assert(checkBoolean === true, s"boolean field value in line $index did not match") - if ((index % 15) % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") - assert(row(2) === "abc", s"string field value in line $index did not match") - assert(row(3) === ((index.toLong % 15) << 33), s"long value in line $index did not match") - assert(row(4) === 2.5F, s"float field value in line $index did not match") - assert(row(5) === 4.5D, s"double field value in line $index did not match") - } - } + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala deleted file mode 100644 index 4c081fb4510b2..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ - -/** - * A test suite that tests various Parquet queries. 
- */ -class ParquetQuerySuite2 extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext - - test("simple projection") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) - } - } - - test("appending") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT INTO t SELECT * FROM t") - checkAnswer(table("t"), data ++ data) - } - } - - test("self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } - - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output - - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } - - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) - } - } - - test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) - } - } - - test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) - } - } - - test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 264f6d94c4ed9..b1e0919b7aed1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -244,7 +244,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT count(*) FROM tableWithSchema", - 10) + Seq(Row(10))) sqlTest( "SELECT `string$%Field` FROM tableWithSchema", @@ -260,7 +260,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where int_Field=1", - Seq(Seq(1, 2))) + Seq(Row(1, 2))) sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index ebf7003ff9e57..3f20c6142e59a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -20,30 +20,20 @@ package org.apache.spark.sql.hive import scala.language.implicitConversions import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical} +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} /** * A parser that recognizes all HiveQL constructs together with Spark SQL specific 
extensions. */ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object protected val ADD = Keyword("ADD") protected val DFS = Keyword("DFS") protected val FILE = Keyword("FILE") protected val JAR = Keyword("JAR") - private val reservedWords = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl protected lazy val hiveQl: Parser[LogicalPlan] = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 10833c113216a..9d2cfd8e0d669 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} @@ -66,11 +67,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { new this.QueryExecution { val logical = plan } override def sql(sqlText: String): SchemaRDD = { + val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { - super.sql(sqlText) + super.sql(substituted) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText))) + new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } @@ -368,10 +370,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { .mkString("\t") } case command: ExecutedCommand => - command.executeCollect().map(_.head.toString) + command.executeCollect().map(_(0).toString) case other => - val result: Seq[Seq[Any]] = other.executeCollect().toSeq + val result: Seq[Seq[Any]] = other.executeCollect().map(_.toSeq).toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. @@ -395,7 +397,7 @@ private object HiveContext { protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => - struct.zip(fields).map { + struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => @@ -418,7 +420,7 @@ private object HiveContext { /** Hive outputs fields of structs slightly differently than top level attributes. 
*/ protected def toHiveStructString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => - struct.zip(fields).map { + struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index eeabfdd857916..82dba99900df9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -348,7 +348,7 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach { (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) } struct @@ -432,7 +432,7 @@ private[hive] trait HiveInspectors { } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Seq[_]] + val row = a.asInstanceOf[Row] // 1. create the pojo (most likely) object val result = x.create() var i = 0 @@ -448,7 +448,7 @@ private[hive] trait HiveInspectors { result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Seq[_]] + val row = a.asInstanceOf[Row] val result = new java.util.ArrayList[AnyRef](fieldRefs.length) var i = 0 while (i < fieldRefs.length) { @@ -475,7 +475,7 @@ private[hive] trait HiveInspectors { } def wrap( - row: Seq[Any], + row: Row, inspectors: Seq[ObjectInspector], cache: Array[AnyRef]): Array[AnyRef] = { var i = 0 @@ -486,6 +486,18 @@ private[hive] trait HiveInspectors { cache } + def wrap( + row: Seq[Any], + inspectors: Seq[ObjectInspector], + cache: Array[AnyRef]): Array[AnyRef] = { + var i = 0 + while (i < inspectors.length) { + cache(i) = wrap(row(i), inspectors(i)) + i += 1 + } + cache + } + /** * @param dataType Catalyst data type * @return Hive java object inspector (recursively), not the Writable ObjectInspector diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d898b876c39f8..76d2140372197 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -360,7 +360,7 @@ private[hive] case class HiveUdafFunction( protected lazy val cached = new Array[AnyRef](exprs.length) def update(input: Row): Unit = { - val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray + val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index cc8bb3e172c6e..aae175e426ade 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -209,7 +209,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { val dynamicPartPath = dynamicPartColNames - .zip(row.takeRight(dynamicPartColNames.length)) + 
.zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => val string = if (rawVal == null) null else String.valueOf(rawVal) s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index f89c49d292c6c..f320d732fb77a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util._ * So, we duplicate this code here. */ class QueryTest extends PlanTest { + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer @@ -56,17 +57,20 @@ class QueryTest extends PlanTest { * @param rdd the [[SchemaRDD]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = { - val convertedAnswer = expectedAnswer match { - case s: Seq[_] if s.isEmpty => s - case s: Seq[_] if s.head.isInstanceOf[Product] && - !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq) - case s: Seq[_] => s - case singleItem => Seq(Seq(singleItem)) + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). 
+ val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case o => o + }) + } + if (!isSorted) converted.sortBy(_.toString) else converted } - - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty - def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer val sparkAnswer = try rdd.collect().toSeq catch { case e: Exception => fail( @@ -74,11 +78,12 @@ class QueryTest extends PlanTest { |Exception thrown while executing query: |${rdd.queryExecution} |== Exception == - |${stackTraceToString(e)} + |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} """.stripMargin) } - if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { fail(s""" |Results do not match for query: |${rdd.logicalPlan} @@ -88,11 +93,22 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } + + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 4864607252034..2d3ff680125ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -129,6 +129,12 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { } } + def checkValues(row1: Seq[Any], row2: Row): Unit = { + row1.zip(row2.toSeq).map { + case (r1, r2) => checkValue(r1, r2) + } + } + def checkValue(v1: Any, v2: Any): Unit = { (v1, v2) match { case (r1: Decimal, r2: Decimal) => @@ -198,7 +204,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) - checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + checkValues(row, unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 7cfb875e05db3..0e6636d38ed3c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -43,7 +43,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq + testData.collect().toSeq.map(Row.fromTuple) ) // Add more data. 
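Note: the test-suite hunks in this patch consistently replace tuple- or Seq-based expected answers with explicit Row values, matching the new checkAnswer(rdd, expectedAnswer: Seq[Row]) signature introduced above. A minimal usage sketch (not part of the patch; the sample tuples and table name are purely illustrative), assuming a QueryTest subclass with sql in scope:

    // Hedged sketch: building Row-based expectations for the new checkAnswer signature.
    import org.apache.spark.sql.catalyst.expressions.Row

    val rawExpected = Seq((1, "a"), (2, "b"))                 // plain Scala tuples
    val expected: Seq[Row] = rawExpected.map(Row.fromTuple)   // wrap each tuple in a Row
    // checkAnswer(sql("SELECT key, value FROM someTable ORDER BY key"), expected)

Wrapping expectations in Row up front keeps the comparison logic (BigDecimal normalization, optional sorting) in one place in prepareAnswer instead of converting ad hoc per test.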
@@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq ++ testData.collect().toSeq + testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq ) // Now overwrite. @@ -61,7 +61,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the registered table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq + testData.collect().toSeq.map(Row.fromTuple) ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 53d8aa7739bc2..7408c7ffd69e8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -155,7 +155,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), - ("a", "b") :: Nil) + Row("a", "b")) FileUtils.deleteDirectory(tempDir) sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -164,14 +164,14 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // will show. checkAnswer( sql("SELECT * FROM jsonTable"), - ("a1", "b1") :: Nil) + Row("a1", "b1")) refreshTable("jsonTable") // Check that the refresh worked checkAnswer( sql("SELECT * FROM jsonTable"), - ("a1", "b1", "c1") :: Nil) + Row("a1", "b1", "c1")) FileUtils.deleteDirectory(tempDir) } @@ -191,7 +191,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), - ("a", "b") :: Nil) + Row("a", "b")) FileUtils.deleteDirectory(tempDir) sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -210,7 +210,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // New table should reflect new schema. 
checkAnswer( sql("SELECT * FROM jsonTable"), - ("a", "b", "c") :: Nil) + Row("a", "b", "c")) FileUtils.deleteDirectory(tempDir) } @@ -253,6 +253,6 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |) """.stripMargin) - sql("DROP TABLE jsonTable").collect.foreach(println) + sql("DROP TABLE jsonTable").collect().foreach(println) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 0b4e76c9d3d2f..6f07fd5a879c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag -import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -141,7 +141,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { before: () => Unit, after: () => Unit, query: String, - expectedAnswer: Seq[Any], + expectedAnswer: Seq[Row], ct: ClassTag[_]) = { before() @@ -183,7 +183,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { /** Tests for MetastoreRelation */ val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key""" - val metastoreAnswer = Seq.fill(4)((238, "val_238", 238, "val_238")) + val metastoreAnswer = Seq.fill(4)(Row(238, "val_238", 238, "val_238")) mkTest( () => (), () => (), @@ -197,7 +197,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val leftSemiJoinQuery = """SELECT * FROM src a |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin - val answer = (86, "val_86") :: Nil + val answer = Row(86, "val_86") var rdd = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index c14f0d24e0dc3..df72be7746ac6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -226,7 +226,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Jdk version leads to different query output for double, so not use createQueryTest here test("division") { val res = sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head - Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res).foreach( x => + Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res.toSeq).foreach( x => assert(x._1 == x._2.asInstanceOf[Double])) } @@ -235,7 +235,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") - assert(sql("SELECT 1").collect() === Array(Seq(1))) + assert(sql("SELECT 1").collect() === Array(Row(1))) setConf("spark.sql.dialect", "hiveql") } @@ -467,7 +467,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestData(2, "str2") :: Nil) testData.registerTempTable("REGisteredTABle") - assertResult(Array(Array(2, "str2"))) { + assertResult(Array(Row(2, "str2"))) { sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + "WHERE TableAliaS.a > 1").collect() } @@ -553,12 +553,12 @@ class HiveQuerySuite extends 
HiveComparisonTest with BeforeAndAfter { // Describe a table assertResult( Array( - Array("key", "int", null), - Array("value", "string", null), - Array("dt", "string", null), - Array("# Partition Information", "", ""), - Array("# col_name", "data_type", "comment"), - Array("dt", "string", null)) + Row("key", "int", null), + Row("value", "string", null), + Row("dt", "string", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("dt", "string", null)) ) { sql("DESCRIBE test_describe_commands1") .select('col_name, 'data_type, 'comment) @@ -568,12 +568,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Describe a table with a fully qualified table name assertResult( Array( - Array("key", "int", null), - Array("value", "string", null), - Array("dt", "string", null), - Array("# Partition Information", "", ""), - Array("# col_name", "data_type", "comment"), - Array("dt", "string", null)) + Row("key", "int", null), + Row("value", "string", null), + Row("dt", "string", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("dt", "string", null)) ) { sql("DESCRIBE default.test_describe_commands1") .select('col_name, 'data_type, 'comment) @@ -623,8 +623,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertResult( Array( - Array("a", "IntegerType", null), - Array("b", "StringType", null)) + Row("a", "IntegerType", null), + Row("b", "StringType", null)) ) { sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 5dafcd6c0a76a..f2374a215291b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -64,7 +64,7 @@ class HiveUdfSuite extends QueryTest { test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { checkAnswer( sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), - 8 + Row(8) ) } @@ -115,7 +115,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") checkAnswer( sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), - Seq(Seq("1"), Seq("2"))) + Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") TestHive.reset() @@ -131,7 +131,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), - Seq(Seq(0), Seq(2), Seq(13))) + Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") TestHive.reset() @@ -146,7 +146,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), - Seq(Seq("a,b,c"), Seq("d,e"))) + Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") TestHive.reset() @@ -160,7 +160,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") checkAnswer( sql("SELECT 
testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), - Seq(Seq("hello world"), Seq("hello goodbye"))) + Seq(Row("hello world"), Row("hello goodbye"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") TestHive.reset() @@ -177,7 +177,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), - Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13"))) + Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") TestHive.reset() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index d41eb9e870bf0..7f9f1ac7cd80d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -41,7 +41,7 @@ class SQLQuerySuite extends QueryTest { } test("CTAS with serde") { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( """CREATE TABLE ctas2 | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" @@ -51,23 +51,23 @@ class SQLQuerySuite extends QueryTest { | AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect + | ORDER BY key, value""".stripMargin).collect() sql( """CREATE TABLE ctas3 | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' | STORED AS textfile AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect + | ORDER BY key, value""".stripMargin).collect() // the table schema may like (key: integer, value: string) sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect + | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect() // do nothing cause the table ctas4 already existed. 
sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() checkAnswer( sql("SELECT k, value FROM ctas1 ORDER BY k, value"), @@ -89,7 +89,7 @@ class SQLQuerySuite extends QueryTest { intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] { sql( """CREATE TABLE ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), @@ -104,6 +104,24 @@ class SQLQuerySuite extends QueryTest { ) } + test("command substitution") { + sql("set tbl=src") + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + + sql("set hive.variable.substitute=false") // disable the substitution + sql("set tbl2=src") + intercept[Exception] { + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect() + } + + sql("set hive.variable.substitute=true") // enable the substitution + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + } + test("ordering not in select") { checkAnswer( sql("SELECT key FROM src ORDER BY value"), @@ -126,7 +144,7 @@ class SQLQuerySuite extends QueryTest { sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), - 1) + Row(1)) checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), Seq.empty[Row]) checkAnswer( @@ -233,7 +251,7 @@ class SQLQuerySuite extends QueryTest { | (s struct, | innerArray:array, | innerMap: map>) - """.stripMargin).collect + """.stripMargin).collect() sql( """ @@ -243,7 +261,7 @@ class SQLQuerySuite extends QueryTest { checkAnswer( sql("SELECT * FROM nullValuesInInnerComplexTypes"), - Seq(Seq(Seq(null, null, null))) + Row(Row(null, null, null)) ) sql("DROP TABLE nullValuesInInnerComplexTypes") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 4bc14bad0ad5f..581f666399492 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -39,7 +39,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("SELECT on Parquet table") { val data = (1 to 4).map(i => (i, s"val_$i")) withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data) + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index 8bbb7f2fdbf48..79fd99d9f89ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -177,81 +177,81 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll test(s"ordering of the partitioning columns $table") { checkAnswer( sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)((1, "part-1")) + Seq.fill(10)(Row(1, "part-1")) ) checkAnswer( sql(s"SELECT stringField, p FROM 
$table WHERE p = 1"), - Seq.fill(10)(("part-1", 1)) + Seq.fill(10)(Row("part-1", 1)) ) } test(s"project the partitioning column $table") { checkAnswer( sql(s"SELECT p, count(*) FROM $table group by p"), - (1, 10) :: - (2, 10) :: - (3, 10) :: - (4, 10) :: - (5, 10) :: - (6, 10) :: - (7, 10) :: - (8, 10) :: - (9, 10) :: - (10, 10) :: Nil + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil ) } test(s"project partitioning and non-partitioning columns $table") { checkAnswer( sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - ("part-1", 1, 10) :: - ("part-2", 2, 10) :: - ("part-3", 3, 10) :: - ("part-4", 4, 10) :: - ("part-5", 5, 10) :: - ("part-6", 6, 10) :: - ("part-7", 7, 10) :: - ("part-8", 8, 10) :: - ("part-9", 9, 10) :: - ("part-10", 10, 10) :: Nil + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil ) } test(s"simple count $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table"), - 100) + Row(100)) } test(s"pruned count $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), - 10) + Row(10)) } test(s"non-existant partition $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - 0) + Row(0)) } test(s"multi-partition pruned count $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - 30) + Row(30)) } test(s"non-partition predicates $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - 30) + Row(30)) } test(s"sum $table") { checkAnswer( sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - 1 + 2 + 3) + Row(1 + 2 + 3)) } test(s"hive udfs $table") { @@ -266,6 +266,6 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll test("non-part select(*)") { checkAnswer( sql("SELECT COUNT(*) FROM normal_parquet"), - 10) + Row(10)) } } diff --git a/streaming/pom.xml b/streaming/pom.xml index d3c6d0347a622..22b0d714b57f6 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -96,5 +96,13 @@ + + + ../python + + pyspark/streaming/*.py + + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index e59c24adb84af..0e285d6088ec1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } + /** + * Get the maximum remember duration across all the input streams. This is a conservative but + * safe remember duration which can be used to perform cleanup operations. 
+ */ + def getMaxInputStreamRememberDuration(): Duration = { + inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { logDebug("DStreamGraph.writeObject used") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index d8695b8e05962..9a2254bcdc1f7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -17,14 +17,15 @@ package org.apache.spark.streaming.api.java +import java.lang.{Boolean => JBoolean} +import java.io.{Closeable, InputStream} +import java.util.{List => JList, Map => JMap} import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import java.io.{Closeable, InputStream} -import java.util.{List => JList, Map => JMap} - import akka.actor.{Props, SupervisorStrategy} +import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.{SparkConf, SparkContext} @@ -250,21 +251,53 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Files must be written to the monitored directory by "moving" them from another * location within the same file system. File names starting with . are ignored. * @param directory HDFS directory to monitor for new file + * @param kClass class of key for reading HDFS file + * @param vClass class of value for reading HDFS file + * @param fClass class of input format for reading HDFS file * @tparam K Key type for reading HDFS file * @tparam V Value type for reading HDFS file * @tparam F Input format for reading HDFS file */ def fileStream[K, V, F <: NewInputFormat[K, V]]( - directory: String): JavaPairInputDStream[K, V] = { - implicit val cmk: ClassTag[K] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] - implicit val cmv: ClassTag[V] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] - implicit val cmf: ClassTag[F] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[F]] + directory: String, + kClass: Class[K], + vClass: Class[V], + fClass: Class[F]): JavaPairInputDStream[K, V] = { + implicit val cmk: ClassTag[K] = ClassTag(kClass) + implicit val cmv: ClassTag[V] = ClassTag(vClass) + implicit val cmf: ClassTag[F] = ClassTag(fClass) ssc.fileStream[K, V, F](directory) } + /** + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * Files must be written to the monitored directory by "moving" them from another + * location within the same file system. File names starting with . are ignored. 
+ * @param directory HDFS directory to monitor for new file + * @param kClass class of key for reading HDFS file + * @param vClass class of value for reading HDFS file + * @param fClass class of input format for reading HDFS file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[K, V, F <: NewInputFormat[K, V]]( + directory: String, + kClass: Class[K], + vClass: Class[V], + fClass: Class[F], + filter: JFunction[Path, JBoolean], + newFilesOnly: Boolean): JavaPairInputDStream[K, V] = { + implicit val cmk: ClassTag[K] = ClassTag(kClass) + implicit val cmv: ClassTag[V] = ClassTag(vClass) + implicit val cmf: ClassTag[F] = ClassTag(fClass) + def fn = (x: Path) => filter.call(x).booleanValue() + ssc.fileStream[K, V, F](directory, fn, newFilesOnly) + } + /** * Create an input stream with any arbitrary user implemented actor receiver. * @param props Props object defining creation of the actor diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index afd3c4bc4c4fe..8be04314c4285 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } - - /** - * Clear metadata that are older than `rememberDuration` of this DStream. - * This is an internal method that should not be called directly. This - * implementation overrides the default implementation to clear received - * block information. - */ - private[streaming] override def clearMetadata(time: Time) { - super.clearMetadata(time) - ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) - } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index ab9fa192191aa..7bf3c33319491 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -17,7 +17,10 @@ package org.apache.spark.streaming.receiver -/** Messages sent to the NetworkReceiver. */ +import org.apache.spark.streaming.Time + +/** Messages sent to the Receiver. 
*/ private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage +private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index d7229c2b96d0b..716cf2c7f32fc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl( case StopReceiver => logInfo("Received stop signal") stop("Stopped by driver", None) + case CleanupOldBlocks(threshTime) => + logDebug("Received delete old batch signal") + cleanupOldBlocks(threshTime) } def ref = self @@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl( /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) + + private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = { + logDebug(s"Cleaning up blocks older than $cleanupThreshTime") + receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 39b66e1130768..d86f852aba97e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -238,13 +238,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { eventActor ! DoCheckpoint(time) } else { + // If checkpointing is not enabled, then delete metadata information about + // received blocks (block data not saved in any case). Otherwise, wait for + // checkpointing of this batch to complete. + val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } } @@ -252,6 +256,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream checkpoint data for the given `time`. 
*/ private def clearCheckpointData(time: Time) { ssc.graph.clearCheckpointData(time) + + // All the checkpoint information about which batches have been processed, etc., has + // been saved to checkpoints, so it's safe to delete block metadata and data WAL files + val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index c3d9d7b6813d3..ef23b5c79f2e1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -150,7 +150,6 @@ private[streaming] class ReceivedBlockTracker( writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) - log } /** Stop the block tracker. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 8dbb42a86e3bd..4f998869731ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -24,9 +24,8 @@ import scala.language.existentials import akka.actor._ import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - /** Clean up metadata older than the given threshold time */ - def cleanupOldMetadata(cleanupThreshTime: Time) { + /** + * Clean up the data and metadata of blocks and batches that are strictly + * older than the threshold time. Note that this does not + * wait for the cleanup to complete. + */ + def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) { + // Clean up old block and batch metadata receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) + + // Signal the receivers to delete old block data + if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + logInfo(s"Cleanup old received batch data: $cleanupThreshTime") + receiverInfo.values.flatMap { info => Option(info.actor) } + .foreach { _ !
CleanupOldBlocks(cleanupThreshTime) } + } } /** Register a receiver */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 27a28bab83ed5..858ba3c9eb4e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -63,7 +63,7 @@ private[streaming] object HdfsUtils { } def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = { - // For local file systems, return the raw loca file system, such calls to flush() + // For local file systems, return the raw local file system, such calls to flush() // actually flushes the stream. val fs = path.getFileSystem(conf) fs match { diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 12cc0de7509d6..d92e7fe899a09 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -17,13 +17,20 @@ package org.apache.spark.streaming; +import java.io.*; +import java.lang.Iterable; +import java.nio.charset.Charset; +import java.util.*; + +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; import scala.Tuple2; import org.junit.Assert; +import static org.junit.Assert.*; import org.junit.Test; -import java.io.*; -import java.util.*; -import java.lang.Iterable; import com.google.common.base.Optional; import com.google.common.collect.Lists; @@ -1743,13 +1750,66 @@ public Iterable call(InputStream in) throws IOException { StorageLevel.MEMORY_ONLY()); } + @SuppressWarnings("unchecked") + @Test + public void testTextFileStream() throws IOException { + File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir")); + List> expected = fileTestPrepare(testDir); + + JavaDStream input = ssc.textFileStream(testDir.toString()); + JavaTestUtils.attachTestOutputStream(input); + List> result = JavaTestUtils.runStreams(ssc, 1, 1); + + assertOrderInvariantEquals(expected, result); + } + + @SuppressWarnings("unchecked") @Test - public void testTextFileStream() { - JavaDStream test = ssc.textFileStream("/tmp/foo"); + public void testFileStream() throws IOException { + File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir")); + List> expected = fileTestPrepare(testDir); + + JavaPairInputDStream inputStream = ssc.fileStream( + testDir.toString(), + LongWritable.class, + Text.class, + TextInputFormat.class, + new Function() { + @Override + public Boolean call(Path v1) throws Exception { + return Boolean.TRUE; + } + }, + true); + + JavaDStream test = inputStream.map( + new Function, String>() { + @Override + public String call(Tuple2 v1) throws Exception { + return v1._2().toString(); + } + }); + + JavaTestUtils.attachTestOutputStream(test); + List> result = JavaTestUtils.runStreams(ssc, 1, 1); + + assertOrderInvariantEquals(expected, result); } @Test public void testRawSocketStream() { JavaReceiverInputDStream test = ssc.rawSocketStream("localhost", 12345); } + + private List> fileTestPrepare(File testDir) throws IOException { + File existingFile = new File(testDir, "0"); + Files.write("0\n", existingFile, Charset.forName("UTF-8")); + assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() 
== 1000); + + List> expected = Arrays.asList( + Arrays.asList("0") + ); + + return expected; + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e26c0c6859e57..e8c34a9ee40b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -17,21 +17,26 @@ package org.apache.spark.streaming +import java.io.File import java.nio.ByteBuffer import java.util.concurrent.Semaphore +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkConf -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor} -import org.scalatest.FunSuite +import com.google.common.io.Files import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver._ +import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ + /** Testsuite for testing the network receiver behavior */ -class ReceiverSuite extends FunSuite with Timeouts { +class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("receiver life cycle") { @@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts { val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") - println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes) assert( // the first and last block may be incomplete, so we slice them out recordedBlocks.drop(1).dropRight(1).forall { block => @@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts { ) } - /** - * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. + * Test whether write ahead logs are generated by receivers, + * and automatically cleaned up. The cleanup must be aware of the + * remember duration of the input streams. E.g., input streams on which window() + * has been applied must remember the data for longer, and hence corresponding + * WALs should be cleaned later.
*/ - class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - @volatile var otherThread: Thread = null - @volatile var receiving = false - @volatile var onStartCalled = false - @volatile var onStopCalled = false - - def onStart() { - otherThread = new Thread() { - override def run() { - receiving = true - while(!isStopped()) { - Thread.sleep(10) - } + test("write ahead log - generating and cleaning") { + val sparkConf = new SparkConf() + .setMaster("local[4]") // must be at least 3 as we are going to start 2 receivers + .setAppName(framework) + .set("spark.ui.enabled", "true") + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val batchDuration = Milliseconds(500) + val tempDirectory = Files.createTempDir() + val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) + val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1)) + val allLogFiles1 = new mutable.HashSet[String]() + val allLogFiles2 = new mutable.HashSet[String]() + logInfo("Temp checkpoint directory = " + tempDirectory) + + def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = { + (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2)) + } + + def getCurrentLogFiles(logDirectory: File): Seq[String] = { + try { + if (logDirectory.exists()) { + logDirectory.listFiles().filter { _.getName.startsWith("log") }.map { _.toString } + } else { + Seq.empty } + } catch { + case e: Exception => + Seq.empty } - onStartCalled = true - otherThread.start() - } - def onStop() { - onStopCalled = true - otherThread.join() + def printLogFiles(message: String, files: Seq[String]) { + logInfo(s"$message (${files.size} files):\n" + files.mkString("\n")) } - def reset() { - receiving = false - onStartCalled = false - onStopCalled = false + withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc => + tempDirectory.deleteOnExit() + val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiverStream1 = ssc.receiverStream(receiver1) + val receiverStream2 = ssc.receiverStream(receiver2) + receiverStream1.register() + receiverStream2.window(batchDuration * 6).register() // 3 second window + ssc.checkpoint(tempDirectory.getAbsolutePath()) + ssc.start() + + // Run until sufficient WAL files have been generated and + // the first WAL file has been deleted + eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) { + val (logFiles1, logFiles2) = getBothCurrentLogFiles() + allLogFiles1 ++= logFiles1 + allLogFiles2 ++= logFiles2 + if (allLogFiles1.size > 0) { + assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head)) + } + if (allLogFiles2.size > 0) { + assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head)) + } + assert(allLogFiles1.size >= 7) + assert(allLogFiles2.size >= 7) + } + ssc.stop(stopSparkContext = true, stopGracefully = true) + + val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted + val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted + val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles() + + printLogFiles("Receiver 0: all", sortedAllLogFiles1) + printLogFiles("Receiver 0: left", leftLogFiles1) + printLogFiles("Receiver 1: all", sortedAllLogFiles2) + printLogFiles("Receiver 1: left", leftLogFiles2) + + // Verify that necessary latest log files are not deleted + // receiverStream1 needs to retain just the last batch = 1 log
file + // receiverStream2 needs to retain 3 seconds (3-seconds window) = 3 log files + assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains)) + assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains)) } } @@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts { } } +/** + * An implementation of Receiver that is used for testing a receiver's life cycle. + */ +class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + @volatile var otherThread: Thread = null + @volatile var receiving = false + @volatile var onStartCalled = false + @volatile var onStopCalled = false + + def onStart() { + otherThread = new Thread() { + override def run() { + receiving = true + var count = 0 + while(!isStopped()) { + if (sendData) { + store(count) + count += 1 + } + Thread.sleep(10) + } + } + } + onStartCalled = true + otherThread.start() + } + + def onStop() { + onStopCalled = true + otherThread.join() + } + + def reset() { + receiving = false + onStartCalled = false + onStopCalled = false + } +} + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 79bead77ba6e4..f96b245512271 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -19,9 +19,9 @@ package org.apache.spark.deploy.yarn import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.util.{Utils, IntParam, MemoryParam} +import org.apache.spark.util.{IntParam, MemoryParam, Utils} // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) { @@ -95,6 +95,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException( "You must specify at least 1 executor!\n" + getUsageMessage()) } + if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) { + throw new SparkException("Executor cores must not be less than " + + "spark.task.cpus.") + } if (isClusterMode) { for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) { if (sparkConf.contains(key)) { @@ -222,7 +226,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) - | --executor-cores NUM Number of cores for the executors (Default: 1). + | --executor-cores NUM Number of cores per executor (Default: 1). | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb) | --driver-cores NUM Number of cores used by the driver (Default: 1). | --executor-memory MEM Memory per executor (e.g. 
1000M, 2G) (Default: 1G) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index de65ef23ad1ce..d00f29665a58f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.yarn +import java.util.Collections import java.util.concurrent._ -import java.util.concurrent.atomic.AtomicInteger import java.util.regex.Pattern import scala.collection.JavaConversions._ @@ -28,33 +28,26 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.hadoop.yarn.util.Records +import org.apache.hadoop.yarn.util.RackResolver -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -object AllocationType extends Enumeration { - type AllocationType = Value - val HOST, RACK, ANY = Value -} - -// TODO: -// Too many params. -// Needs to be mt-safe -// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should -// make it more proactive and decoupled. - -// Note that right now, we assume all node asks as uniform in terms of capabilities and priority -// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for -// more info on how we are requesting for containers. - /** - * Acquires resources for executors from a ResourceManager and launches executors in new containers. + * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding + * what to do with containers when YARN fulfills these requests. + * + * This class makes use of YARN's AMRMClient APIs. We interact with the AMRMClient in three ways: + * * Making our resource needs known, which updates local bookkeeping about containers requested. + * * Calling "allocate", which syncs our local container requests with the RM, and returns any + * containers that YARN has granted to us. This also functions as a heartbeat. + * * Processing the containers granted to us to possibly launch executors inside of them. + * + * The public methods of this class are thread-safe. All methods that mutate state are + * synchronized. */ private[yarn] class YarnAllocator( conf: Configuration, @@ -62,50 +55,41 @@ private[yarn] class YarnAllocator( amClient: AMRMClient[ContainerRequest], appAttemptId: ApplicationAttemptId, args: ApplicationMasterArguments, - preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) extends Logging { import YarnAllocator._ - // These three are locked on allocatedHostToContainersMap. Complementary data structures - // allocatedHostToContainersMap : containers which are running : host, Set - // allocatedContainerToHostMap: container to host mapping. - private val allocatedHostToContainersMap = - new HashMap[String, collection.mutable.Set[ContainerId]]() + // Visible for testing. 
+ val allocatedHostToContainersMap = + new HashMap[String, collection.mutable.Set[ContainerId]] + val allocatedContainerToHostMap = new HashMap[ContainerId, String] - private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() + // Containers that we no longer care about. We've either already told the RM to release them or + // will on the next heartbeat. Containers get removed from this map after the RM tells us they've + // completed. + private val releasedContainers = Collections.newSetFromMap[ContainerId]( + new ConcurrentHashMap[ContainerId, java.lang.Boolean]) - // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an - // allocated node) - // As with the two data structures above, tightly coupled with them, and to be locked on - // allocatedHostToContainersMap - private val allocatedRackCount = new HashMap[String, Int]() + @volatile private var numExecutorsRunning = 0 + // Used to generate a unique ID per executor + private var executorIdCounter = 0 + @volatile private var numExecutorsFailed = 0 - // Containers to be released in next request to RM - private val releasedContainers = new ConcurrentHashMap[ContainerId, Boolean] - - // Number of container requests that have been sent to, but not yet allocated by the - // ApplicationMaster. - private val numPendingAllocate = new AtomicInteger() - private val numExecutorsRunning = new AtomicInteger() - // Used to generate a unique id per executor - private val executorIdCounter = new AtomicInteger() - private val numExecutorsFailed = new AtomicInteger() - - private var maxExecutors = args.numExecutors + @volatile private var maxExecutors = args.numExecutors // Keep track of which container is running which executor to remove the executors later private val executorIdToContainer = new HashMap[String, Container] + // Executor memory in MB. protected val executorMemory = args.executorMemory - protected val executorCores = args.executorCores - protected val (preferredHostToCount, preferredRackToCount) = - generateNodeToWeight(conf, preferredNodes) - - // Additional memory overhead - in mb. + // Additional memory overhead. protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) + // Number of cores per executor. + protected val executorCores = args.executorCores + // Resource capability requested for each executors + private val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) private val launcherPool = new ThreadPoolExecutor( // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue @@ -115,26 +99,34 @@ private[yarn] class YarnAllocator( new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) launcherPool.allowCoreThreadTimeOut(true) - def getNumExecutorsRunning: Int = numExecutorsRunning.intValue + private val driverUrl = "akka.tcp://sparkDriver@%s:%s/user/%s".format( + sparkConf.get("spark.driver.host"), + sparkConf.get("spark.driver.port"), + CoarseGrainedSchedulerBackend.ACTOR_NAME) + + // For testing + private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) - def getNumExecutorsFailed: Int = numExecutorsFailed.intValue + def getNumExecutorsRunning: Int = numExecutorsRunning + + def getNumExecutorsFailed: Int = numExecutorsFailed + + /** + * Number of container requests that have not yet been fulfilled. 
+ */ + def getNumPendingAllocate: Int = getNumPendingAtLocation(ANY_HOST) + + /** + * Number of container requests at the given location that have not yet been fulfilled. + */ + private def getNumPendingAtLocation(location: String): Int = + amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum /** * Request as many executors from the ResourceManager as needed to reach the desired total. - * This takes into account executors already running or pending. */ def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { - val currentTotal = numPendingAllocate.get + numExecutorsRunning.get - if (requestedTotal > currentTotal) { - maxExecutors += (requestedTotal - currentTotal) - // We need to call `allocateResources` here to avoid the following race condition: - // If we request executors twice before `allocateResources` is called, then we will end up - // double counting the number requested because `numPendingAllocate` is not updated yet. - allocateResources() - } else { - logInfo(s"Not allocating more executors because there are already $currentTotal " + - s"(application requested $requestedTotal total)") - } + maxExecutors = requestedTotal } /** @@ -144,7 +136,7 @@ private[yarn] class YarnAllocator( if (executorIdToContainer.contains(executorId)) { val container = executorIdToContainer.remove(executorId).get internalReleaseContainer(container) - numExecutorsRunning.decrementAndGet() + numExecutorsRunning -= 1 maxExecutors -= 1 assert(maxExecutors >= 0, "Allocator killed more executors than are allocated!") } else { @@ -153,498 +145,234 @@ private[yarn] class YarnAllocator( } /** - * Allocate missing containers based on the number of executors currently pending and running. + * Request resources such that, if YARN gives us all we ask for, we'll have a number of containers + * equal to maxExecutors. * - * This method prioritizes the allocated container responses from the RM based on node and - * rack locality. Additionally, it releases any extra containers allocated for this application - * but are not needed. This must be synchronized because variables read in this block are - * mutated by other methods. + * Deal with any containers YARN has granted to us by possibly launching executors in them. + * + * This must be synchronized because variables read in this method are mutated by other methods. */ def allocateResources(): Unit = synchronized { - val missing = maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get() + val numPendingAllocate = getNumPendingAllocate + val missing = maxExecutors - numPendingAllocate - numExecutorsRunning if (missing > 0) { - val totalExecutorMemory = executorMemory + memoryOverhead - numPendingAllocate.addAndGet(missing) - logInfo(s"Will allocate $missing executor containers, each with $totalExecutorMemory MB " + - s"memory including $memoryOverhead MB overhead") - } else { - logDebug("Empty allocation request ...") + logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") } - val allocateResponse = allocateContainers(missing) + addResourceRequests(missing) + val progressIndicator = 0.1f + // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container + // requests. 
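The class comment above lists three interactions with the AMRMClient, and getNumPendingAllocate shows that the number of outstanding requests is now read back from the client rather than tracked in a separate counter. A minimal sketch of that request/heartbeat cycle, assuming amClient is registered with a live ResourceManager and with the capability numbers made up:

    import scala.collection.JavaConversions._
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.yarn.api.records.{Priority, Resource}
    import org.apache.hadoop.yarn.client.api.AMRMClient
    import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest

    val amClient: AMRMClient[ContainerRequest] = AMRMClient.createAMRMClient()
    amClient.init(new Configuration())
    amClient.start()
    // registerApplicationMaster(...) against a live RM is assumed to have been called

    val priority = Priority.newInstance(1)
    val capability = Resource.newInstance(2048 + 384, 5)  // hypothetical memory (MB) + overhead, cores

    // 1. Record a resource need; this only updates the client's local bookkeeping.
    amClient.addContainerRequest(new ContainerRequest(capability, null, null, priority))

    // 2. Sync with the ResourceManager; with nothing new to ask for this is just a heartbeat.
    val response = amClient.allocate(0.1f)

    // Pending requests are recovered from the client itself ("*" is the ANY-host resource name),
    // which is how getNumPendingAllocate above avoids keeping its own counter.
    val pending = amClient.getMatchingRequests(priority, "*", capability).map(_.size).sum

    // 3. Containers YARN has granted, to be matched against requests and launched or released.
    val granted = response.getAllocatedContainers()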
+ val allocateResponse = amClient.allocate(progressIndicator) + val allocatedContainers = allocateResponse.getAllocatedContainers() if (allocatedContainers.size > 0) { - var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size) - - if (numPendingAllocateNow < 0) { - numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow) - } - - logDebug(""" - Allocated containers: %d - Current executor count: %d - Containers released: %s - Cluster resources: %s - """.format( + logDebug("Allocated containers: %d. Current executor count: %d. Cluster resources: %s." + .format( allocatedContainers.size, - numExecutorsRunning.get(), - releasedContainers, + numExecutorsRunning, allocateResponse.getAvailableResources)) - val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (container <- allocatedContainers) { - if (isResourceConstraintSatisfied(container)) { - // Add the accepted `container` to the host's list of already accepted, - // allocated containers - val host = container.getNodeId.getHost - val containersForHost = hostToContainers.getOrElseUpdate(host, - new ArrayBuffer[Container]()) - containersForHost += container - } else { - // Release container, since it doesn't satisfy resource constraints. - internalReleaseContainer(container) - } - } - - // Find the appropriate containers to use. - // TODO: Cleanup this group-by... - val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (candidateHost <- hostToContainers.keySet) { - val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) - val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) - - val remainingContainersOpt = hostToContainers.get(candidateHost) - assert(remainingContainersOpt.isDefined) - var remainingContainers = remainingContainersOpt.get - - if (requiredHostCount >= remainingContainers.size) { - // Since we have <= required containers, add all remaining containers to - // `dataLocalContainers`. - dataLocalContainers.put(candidateHost, remainingContainers) - // There are no more free containers remaining. - remainingContainers = null - } else if (requiredHostCount > 0) { - // Container list has more containers than we need for data locality. - // Split the list into two: one based on the data local container count, - // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining - // containers. - val (dataLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredHostCount) - dataLocalContainers.put(candidateHost, dataLocal) - - // Invariant: remainingContainers == remaining - - // YARN has a nasty habit of allocating a ton of containers on a host - discourage this. - // Add each container in `remaining` to list of containers to release. If we have an - // insufficient number of containers, then the next allocation cycle will reallocate - // (but won't treat it as data local). - // TODO(harvey): Rephrase this comment some more. 
- for (container <- remaining) internalReleaseContainer(container) - remainingContainers = null - } - - // For rack local containers - if (remainingContainers != null) { - val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) - if (rack != null) { - val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) - val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - - rackLocalContainers.getOrElse(rack, List()).size - - if (requiredRackCount >= remainingContainers.size) { - // Add all remaining containers to to `dataLocalContainers`. - dataLocalContainers.put(rack, remainingContainers) - remainingContainers = null - } else if (requiredRackCount > 0) { - // Container list has more containers that we need for data locality. - // Split the list into two: one based on the data local container count, - // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining - // containers. - val (rackLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredRackCount) - val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, - new ArrayBuffer[Container]()) - - existingRackLocal ++= rackLocal - - remainingContainers = remaining - } - } - } - - if (remainingContainers != null) { - // Not all containers have been consumed - add them to the list of off-rack containers. - offRackContainers.put(candidateHost, remainingContainers) - } - } - - // Now that we have split the containers into various groups, go through them in order: - // first host-local, then rack-local, and finally off-rack. - // Note that the list we create below tries to ensure that not all containers end up within - // a host if there is a sufficiently large number of hosts/containers. - val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) - - // Run each of the allocated containers. - for (container <- allocatedContainersToProcess) { - val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() - val executorHostname = container.getNodeId.getHost - val containerId = container.getId - - val executorMemoryOverhead = (executorMemory + memoryOverhead) - assert(container.getResource.getMemory >= executorMemoryOverhead) - - if (numExecutorsRunningNow > maxExecutors) { - logInfo("""Ignoring container %s at host %s, since we already have the required number of - containers for it.""".format(containerId, executorHostname)) - internalReleaseContainer(container) - numExecutorsRunning.decrementAndGet() - } else { - val executorId = executorIdCounter.incrementAndGet().toString - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - - logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) - executorIdToContainer(executorId) = container - - // To be safe, remove the container from `releasedContainers`. 
- releasedContainers.remove(containerId) - - val rack = YarnSparkHadoopUtil.lookupRack(conf, executorHostname) - allocatedHostToContainersMap.synchronized { - val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, - new HashSet[ContainerId]()) - - containerSet += containerId - allocatedContainerToHostMap.put(containerId, executorHostname) - - if (rack != null) { - allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) - } - } - logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( - driverUrl, executorHostname)) - val executorRunnable = new ExecutorRunnable( - container, - conf, - sparkConf, - driverUrl, - executorId, - executorHostname, - executorMemory, - executorCores, - appAttemptId.getApplicationId.toString, - securityMgr) - launcherPool.execute(executorRunnable) - } - } - logDebug(""" - Finished allocating %s containers (from %s originally). - Current number of executors running: %d, - Released containers: %s - """.format( - allocatedContainersToProcess, - allocatedContainers, - numExecutorsRunning.get(), - releasedContainers)) + handleAllocatedContainers(allocatedContainers) } val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - for (completedContainer <- completedContainers) { - val containerId = completedContainer.getContainerId + processCompletedContainers(completedContainers) - if (releasedContainers.containsKey(containerId)) { - // Already marked the container for release, so remove it from - // `releasedContainers`. - releasedContainers.remove(containerId) - } else { - // Decrement the number of executors running. The next iteration of - // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning.decrementAndGet() - logInfo("Completed container %s (state: %s, exit status: %s)".format( - containerId, - completedContainer.getState, - completedContainer.getExitStatus)) - // Hadoop 2.2.X added a ContainerExitStatus we should switch to use - // there are some exit status' we shouldn't necessarily count against us, but for - // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus == -103) { // vmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus != 0) { - logInfo("Container marked as failed: " + containerId + - ". Exit status: " + completedContainer.getExitStatus + - ". 
Diagnostics: " + completedContainer.getDiagnostics) - numExecutorsFailed.incrementAndGet() - } - } - - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val hostOpt = allocatedContainerToHostMap.get(containerId) - assert(hostOpt.isDefined) - val host = hostOpt.get - - val containerSetOpt = allocatedHostToContainersMap.get(host) - assert(containerSetOpt.isDefined) - val containerSet = containerSetOpt.get - - containerSet.remove(containerId) - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap.remove(containerId) - - // TODO: Move this part outside the synchronized block? - val rack = YarnSparkHadoopUtil.lookupRack(conf, host) - if (rack != null) { - val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 - if (rackCount > 0) { - allocatedRackCount.put(rack, rackCount) - } else { - allocatedRackCount.remove(rack) - } - } - } - } - } - logDebug(""" - Finished processing %d completed containers. - Current number of executors running: %d, - Released containers: %s - """.format( - completedContainers.size, - numExecutorsRunning.get(), - releasedContainers)) + logDebug("Finished processing %d completed containers. Current running executor count: %d." + .format(completedContainers.size, numExecutorsRunning)) } } - private def allocatedContainersOnHost(host: String): Int = { - allocatedHostToContainersMap.synchronized { - allocatedHostToContainersMap.getOrElse(host, Set()).size + /** + * Request numExecutors additional containers from YARN. Visible for testing. + */ + def addResourceRequests(numExecutors: Int): Unit = { + for (i <- 0 until numExecutors) { + val request = new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY) + amClient.addContainerRequest(request) + val nodes = request.getNodes + val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last + logInfo("Container request (host: %s, capability: %s".format(hostStr, resource)) } } - private def allocatedContainersOnRack(rack: String): Int = { - allocatedHostToContainersMap.synchronized { - allocatedRackCount.getOrElse(rack, 0) + /** + * Handle containers granted by the RM by launching executors on them. + * + * Due to the way the YARN allocation protocol works, certain healthy race conditions can result + * in YARN granting containers that we no longer need. In this case, we release them. + * + * Visible for testing. + */ + def handleAllocatedContainers(allocatedContainers: Seq[Container]): Unit = { + val containersToUse = new ArrayBuffer[Container](allocatedContainers.size) + + // Match incoming requests by host + val remainingAfterHostMatches = new ArrayBuffer[Container] + for (allocatedContainer <- allocatedContainers) { + matchContainerToRequest(allocatedContainer, allocatedContainer.getNodeId.getHost, + containersToUse, remainingAfterHostMatches) } - } - private def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + memoryOverhead) - } - - // A simple method to copy the split info map. 
- private def generateNodeToWeight( - conf: Configuration, - input: collection.Map[String, collection.Set[SplitInfo]]) - : (Map[String, Int], Map[String, Int]) = { - if (input == null) { - return (Map[String, Int](), Map[String, Int]()) + // Match remaining by rack + val remainingAfterRackMatches = new ArrayBuffer[Container] + for (allocatedContainer <- remainingAfterHostMatches) { + val rack = RackResolver.resolve(conf, allocatedContainer.getNodeId.getHost).getNetworkLocation + matchContainerToRequest(allocatedContainer, rack, containersToUse, + remainingAfterRackMatches) } - val hostToCount = new HashMap[String, Int] - val rackToCount = new HashMap[String, Int] - - for ((host, splits) <- input) { - val hostCount = hostToCount.getOrElse(host, 0) - hostToCount.put(host, hostCount + splits.size) + // Assign remaining that are neither node-local nor rack-local + val remainingAfterOffRackMatches = new ArrayBuffer[Container] + for (allocatedContainer <- remainingAfterRackMatches) { + matchContainerToRequest(allocatedContainer, ANY_HOST, containersToUse, + remainingAfterOffRackMatches) + } - val rack = YarnSparkHadoopUtil.lookupRack(conf, host) - if (rack != null) { - val rackCount = rackToCount.getOrElse(host, 0) - rackToCount.put(host, rackCount + splits.size) + if (!remainingAfterOffRackMatches.isEmpty) { + logDebug(s"Releasing ${remainingAfterOffRackMatches.size} unneeded containers that were " + + s"allocated to us") + for (container <- remainingAfterOffRackMatches) { + internalReleaseContainer(container) } } - (hostToCount.toMap, rackToCount.toMap) - } + runAllocatedContainers(containersToUse) - private def internalReleaseContainer(container: Container): Unit = { - releasedContainers.put(container.getId(), true) - amClient.releaseAssignedContainer(container.getId()) + logInfo("Received %d containers from YARN, launching executors on %d of them." + .format(allocatedContainers.size, containersToUse.size)) } /** - * Called to allocate containers in the cluster. + * Looks for requests for the given location that match the given container allocation. If it + * finds one, removes the request so that it won't be submitted again. Places the container into + * containersToUse or remaining. * - * @param count Number of containers to allocate. - * If zero, should still contact RM (as a heartbeat). - * @return Response to the allocation request. + * @param allocatedContainer container that was given to us by YARN + * @location resource name, either a node, rack, or * + * @param containersToUse list of containers that will be used + * @param remaining list of containers that will not be used */ - private def allocateContainers(count: Int): AllocateResponse = { - addResourceRequests(count) - - // We have already set the container request. Poll the ResourceManager for a response. - // This doubles as a heartbeat if there are no pending container requests. 
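Both the rack-matching step above and the getRackForHost overrides later in this diff delegate to Hadoop's RackResolver in place of the hand-rolled host-to-rack cache being removed from YarnSparkHadoopUtil. A minimal sketch of that lookup, assuming the Configuration carries a topology script or DNSToSwitchMapping implementation (without one, Hadoop falls back to the default rack):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.yarn.util.RackResolver

    // RackResolver consults the configured topology mapping and caches results,
    // so repeated lookups for the same host are cheap.
    val hadoopConf = new Configuration()
    val rack: String = RackResolver.resolve(hadoopConf, "host1").getNetworkLocation
    // "/rack1" with a suitable mapping configured, "/default-rack" otherwise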
- val progressIndicator = 0.1f - amClient.allocate(progressIndicator) + private def matchContainerToRequest( + allocatedContainer: Container, + location: String, + containersToUse: ArrayBuffer[Container], + remaining: ArrayBuffer[Container]): Unit = { + val matchingRequests = amClient.getMatchingRequests(allocatedContainer.getPriority, location, + allocatedContainer.getResource) + + // Match the allocation to a request + if (!matchingRequests.isEmpty) { + val containerRequest = matchingRequests.get(0).iterator.next + amClient.removeContainerRequest(containerRequest) + containersToUse += allocatedContainer + } else { + remaining += allocatedContainer + } } - private def createRackResourceRequests(hostContainers: ArrayBuffer[ContainerRequest]) - : ArrayBuffer[ContainerRequest] = { - // Generate modified racks and new set of hosts under it before issuing requests. - val rackToCounts = new HashMap[String, Int]() - - for (container <- hostContainers) { - val candidateHost = container.getNodes.last - assert(YarnSparkHadoopUtil.ANY_HOST != candidateHost) - - val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) - if (rack != null) { - var count = rackToCounts.getOrElse(rack, 0) - count += 1 - rackToCounts.put(rack, count) + /** + * Launches executors in the allocated containers. + */ + private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = { + for (container <- containersToUse) { + numExecutorsRunning += 1 + assert(numExecutorsRunning <= maxExecutors) + val executorHostname = container.getNodeId.getHost + val containerId = container.getId + executorIdCounter += 1 + val executorId = executorIdCounter.toString + + assert(container.getResource.getMemory >= resource.getMemory) + + logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) + + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, + new HashSet[ContainerId]) + + containerSet += containerId + allocatedContainerToHostMap.put(containerId, executorHostname) + + val executorRunnable = new ExecutorRunnable( + container, + conf, + sparkConf, + driverUrl, + executorId, + executorHostname, + executorMemory, + executorCores, + appAttemptId.getApplicationId.toString, + securityMgr) + if (launchContainers) { + logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( + driverUrl, executorHostname)) + launcherPool.execute(executorRunnable) } } - - val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size) - for ((rack, count) <- rackToCounts) { - requestedContainers ++= createResourceRequests( - AllocationType.RACK, - rack, - count, - RM_REQUEST_PRIORITY) - } - - requestedContainers } - private def addResourceRequests(numExecutors: Int): Unit = { - val containerRequests: List[ContainerRequest] = - if (numExecutors <= 0) { - logDebug("numExecutors: " + numExecutors) - List() - } else if (preferredHostToCount.isEmpty) { - logDebug("host preferences is empty") - createResourceRequests( - AllocationType.ANY, - resource = null, - numExecutors, - RM_REQUEST_PRIORITY).toList + private def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { + for (completedContainer <- completedContainers) { + val containerId = completedContainer.getContainerId + + if (releasedContainers.contains(containerId)) { + // Already marked the container for release, so remove it from + // `releasedContainers`. 
+ releasedContainers.remove(containerId) } else { - // Request for all hosts in preferred nodes and for numExecutors - - // candidates.size, request by default allocation policy. - val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size) - for ((candidateHost, candidateCount) <- preferredHostToCount) { - val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost) - - if (requiredCount > 0) { - hostContainerRequests ++= createResourceRequests( - AllocationType.HOST, - candidateHost, - requiredCount, - RM_REQUEST_PRIORITY) - } + // Decrement the number of executors running. The next iteration of + // the ApplicationMaster's reporting thread will take care of allocating. + numExecutorsRunning -= 1 + logInfo("Completed container %s (state: %s, exit status: %s)".format( + containerId, + completedContainer.getState, + completedContainer.getExitStatus)) + // Hadoop 2.2.X added a ContainerExitStatus we should switch to use + // there are some exit status' we shouldn't necessarily count against us, but for + // now I think its ok as none of the containers are expected to exit + if (completedContainer.getExitStatus == -103) { // vmem limit exceeded + logWarning(memLimitExceededLogMessage( + completedContainer.getDiagnostics, + VMEM_EXCEEDED_PATTERN)) + } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded + logWarning(memLimitExceededLogMessage( + completedContainer.getDiagnostics, + PMEM_EXCEEDED_PATTERN)) + } else if (completedContainer.getExitStatus != 0) { + logInfo("Container marked as failed: " + containerId + + ". Exit status: " + completedContainer.getExitStatus + + ". Diagnostics: " + completedContainer.getDiagnostics) + numExecutorsFailed += 1 } - val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests( - hostContainerRequests).toList - - val anyContainerRequests = createResourceRequests( - AllocationType.ANY, - resource = null, - numExecutors, - RM_REQUEST_PRIORITY) - - val containerRequestBuffer = new ArrayBuffer[ContainerRequest]( - hostContainerRequests.size + rackContainerRequests.size + anyContainerRequests.size) - - containerRequestBuffer ++= hostContainerRequests - containerRequestBuffer ++= rackContainerRequests - containerRequestBuffer ++= anyContainerRequests - containerRequestBuffer.toList } - for (request <- containerRequests) { - amClient.addContainerRequest(request) - } + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).get + val containerSet = allocatedHostToContainersMap.get(host).get - for (request <- containerRequests) { - val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) { - "Any" - } else { - nodes.last - } - logInfo("Container request (host: %s, priority: %s, capability: %s".format( - hostStr, - request.getPriority().getPriority, - request.getCapability)) - } - } + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) + } - private def createResourceRequests( - requestType: AllocationType.AllocationType, - resource: String, - numExecutors: Int, - priority: Int): ArrayBuffer[ContainerRequest] = { - // If hostname is specified, then we need at least two requests - node local and rack local. - // There must be a third request, which is ANY. That will be specially handled. 
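The -103/-104 exit-status branches above reuse memLimitExceededLogMessage and the VMEM/PMEM patterns brought in by `import YarnAllocator._`; their definitions sit in the companion object, outside this hunk, so the following is only a sketch of the shape such helpers plausibly take (the exact regex and wording are assumptions):

    import java.util.regex.Pattern

    // Assumed shape of the companion-object helpers referenced above.
    val MEM_REGEX = "[0-9.]+ [KMG]B"
    val PMEM_EXCEEDED_PATTERN = Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used")
    val VMEM_EXCEEDED_PATTERN = Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used")

    def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = {
      val matcher = pattern.matcher(diagnostics)
      val diag = if (matcher.find()) " " + matcher.group() + "." else ""
      "Container killed by YARN for exceeding memory limits." + diag +
        " Consider boosting spark.yarn.executor.memoryOverhead."
    }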
- requestType match { - case AllocationType.HOST => { - assert(YarnSparkHadoopUtil.ANY_HOST != resource) - val hostname = resource - val nodeLocal = constructContainerRequests( - Array(hostname), - racks = null, - numExecutors, - priority) - - // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler. - YarnSparkHadoopUtil.populateRackInfo(conf, hostname) - nodeLocal - } - case AllocationType.RACK => { - val rack = resource - constructContainerRequests(hosts = null, Array(rack), numExecutors, priority) + allocatedContainerToHostMap.remove(containerId) } - case AllocationType.ANY => constructContainerRequests( - hosts = null, racks = null, numExecutors, priority) - case _ => throw new IllegalArgumentException( - "Unexpected/unsupported request type: " + requestType) } } - private def constructContainerRequests( - hosts: Array[String], - racks: Array[String], - numExecutors: Int, - priority: Int - ): ArrayBuffer[ContainerRequest] = { - val memoryRequest = executorMemory + memoryOverhead - val resource = Resource.newInstance(memoryRequest, executorCores) - - val prioritySetting = Records.newRecord(classOf[Priority]) - prioritySetting.setPriority(priority) - - val requests = new ArrayBuffer[ContainerRequest]() - for (i <- 0 until numExecutors) { - requests += new ContainerRequest(resource, hosts, racks, prioritySetting) - } - requests + private def internalReleaseContainer(container: Container): Unit = { + releasedContainers.add(container.getId()) + amClient.releaseAssignedContainer(container.getId()) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index b45e599588ad3..b134751366522 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -72,8 +72,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, - preferredNodeLocations, securityMgr) + new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr) } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index d7cf904db1c9e..4bff846123619 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.api.records.ApplicationAccessType +import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} import org.apache.hadoop.yarn.util.RackResolver import org.apache.hadoop.conf.Configuration @@ -99,13 +99,7 @@ object YarnSparkHadoopUtil { // All RM requests are issued with same priority : we do not (yet) have any distinction between // request types (like map/reduce in hadoop for example) - val RM_REQUEST_PRIORITY = 1 - - // Host to rack map - saved from allocation requests. We are expecting this not to change. 
- // Note that it is possible for this to change : and ResourceManager will indicate that to us via - // update response to allocate. But we are punting on handling that for now. - private val hostToRack = new ConcurrentHashMap[String, String]() - private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() + val RM_REQUEST_PRIORITY = Priority.newInstance(1) /** * Add a path variable to the given environment map. @@ -184,37 +178,6 @@ object YarnSparkHadoopUtil { } } - def lookupRack(conf: Configuration, host: String): String = { - if (!hostToRack.contains(host)) { - populateRackInfo(conf, host) - } - hostToRack.get(host) - } - - def populateRackInfo(conf: Configuration, hostname: String) { - Utils.checkHost(hostname) - - if (!hostToRack.containsKey(hostname)) { - // If there are repeated failures to resolve, all to an ignore list. - val rackInfo = RackResolver.resolve(conf, hostname) - if (rackInfo != null && rackInfo.getNetworkLocation != null) { - val rack = rackInfo.getNetworkLocation - hostToRack.put(hostname, rack) - if (! rackToHostSet.containsKey(rack)) { - rackToHostSet.putIfAbsent(rack, - Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) - } - rackToHostSet.get(rack).add(hostname) - - // TODO(harvey): Figure out what this comment means... - // Since RackResolver caches, we are disabling this for now ... - } /* else { - // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... - hostToRack.put(hostname, null) - } */ - } - } - def getApplicationAclsForYarn(securityMgr: SecurityManager) : Map[ApplicationAccessType, String] = { Map[ApplicationAccessType, String] ( diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 254774a6b839e..2fa24cc43325e 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -17,8 +17,9 @@ package org.apache.spark.scheduler.cluster +import org.apache.hadoop.yarn.util.RackResolver + import org.apache.spark._ -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils @@ -30,6 +31,6 @@ private[spark] class YarnClientClusterScheduler(sc: SparkContext) extends TaskSc // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { val host = Utils.parseHostPort(hostPort)._1 - Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host)) + Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation) } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 4157ff95c2794..be55d26f1cf61 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -17,8 +17,10 @@ package org.apache.spark.scheduler.cluster +import org.apache.hadoop.yarn.util.RackResolver + import org.apache.spark._ -import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.ApplicationMaster import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils @@ -39,7 +41,7 @@ private[spark] class 
YarnClusterScheduler(sc: SparkContext) extends TaskSchedule // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { val host = Utils.parseHostPort(hostPort)._1 - Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host)) + Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation) } override def postStartHook() { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 8d184a09d64cc..024b25f9d3365 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -17,18 +17,160 @@ package org.apache.spark.deploy.yarn +import java.util.{Arrays, List => JList} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.CommonConfigurationKeysPublic +import org.apache.hadoop.net.DNSToSwitchMapping +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +import org.apache.spark.SecurityManager +import org.apache.spark.SparkConf +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ -import org.scalatest.FunSuite +import org.apache.spark.scheduler.SplitInfo + +import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} + +class MockResolver extends DNSToSwitchMapping { + + override def resolve(names: JList[String]): JList[String] = { + if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2") + else Arrays.asList("/rack1") + } + + override def reloadCachedMappings() {} + + def reloadCachedMappings(names: JList[String]) {} +} + +class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach { + val conf = new Configuration() + conf.setClass( + CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, + classOf[MockResolver], classOf[DNSToSwitchMapping]) + + val sparkConf = new SparkConf() + sparkConf.set("spark.driver.host", "localhost") + sparkConf.set("spark.driver.port", "4040") + sparkConf.set("spark.yarn.jar", "notarealjar.jar") + sparkConf.set("spark.yarn.launchContainers", "false") + + val appAttemptId = ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0) + + // Resource returned by YARN. YARN can give larger containers than requested, so give 6 cores + // instead of the 5 requested and 3 GB instead of the 2 requested. 
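Because MockResolver is registered through NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY above, rack lookups made by the allocator under test become deterministic. Roughly, reusing the suite's conf (a sketch, not part of the test file):

    import org.apache.hadoop.yarn.util.RackResolver

    // With MockResolver wired into `conf`, "host3" resolves to /rack2 and other hosts to /rack1.
    RackResolver.resolve(conf, "host3").getNetworkLocation  // "/rack2"
    RackResolver.resolve(conf, "host1").getNetworkLocation  // "/rack1"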
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 8d184a09d64cc..024b25f9d3365 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -17,18 +17,160 @@
 
 package org.apache.spark.deploy.yarn
 
+import java.util.{Arrays, List => JList}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic
+import org.apache.hadoop.net.DNSToSwitchMapping
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+
+import org.apache.spark.SecurityManager
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
 import org.apache.spark.deploy.yarn.YarnAllocator._
-import org.scalatest.FunSuite
+import org.apache.spark.scheduler.SplitInfo
+
+import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
+
+class MockResolver extends DNSToSwitchMapping {
+
+  override def resolve(names: JList[String]): JList[String] = {
+    if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2")
+    else Arrays.asList("/rack1")
+  }
+
+  override def reloadCachedMappings() {}
+
+  def reloadCachedMappings(names: JList[String]) {}
+}
+
+class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach {
+  val conf = new Configuration()
+  conf.setClass(
+    CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
+    classOf[MockResolver], classOf[DNSToSwitchMapping])
+
+  val sparkConf = new SparkConf()
+  sparkConf.set("spark.driver.host", "localhost")
+  sparkConf.set("spark.driver.port", "4040")
+  sparkConf.set("spark.yarn.jar", "notarealjar.jar")
+  sparkConf.set("spark.yarn.launchContainers", "false")
+
+  val appAttemptId = ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0)
+
+  // Resource returned by YARN. YARN can give larger containers than requested, so give 6 cores
+  // instead of the 5 requested and 3 GB instead of the 2 requested.
+  val containerResource = Resource.newInstance(3072, 6)
+
+  var rmClient: AMRMClient[ContainerRequest] = _
+
+  var containerNum = 0
+
+  override def beforeEach() {
+    rmClient = AMRMClient.createAMRMClient()
+    rmClient.init(conf)
+    rmClient.start()
+  }
+
+  override def afterEach() {
+    rmClient.stop()
+  }
+
+  class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) {
+    override def equals(other: Any) = false
+  }
+
+  def createAllocator(maxExecutors: Int = 5): YarnAllocator = {
+    val args = Array(
+      "--num-executors", s"$maxExecutors",
+      "--executor-cores", "5",
+      "--executor-memory", "2048",
+      "--jar", "somejar.jar",
+      "--class", "SomeClass")
+    new YarnAllocator(
+      conf,
+      sparkConf,
+      rmClient,
+      appAttemptId,
+      new ApplicationMasterArguments(args),
+      new SecurityManager(sparkConf))
+  }
+
+  def createContainer(host: String): Container = {
+    val containerId = ContainerId.newInstance(appAttemptId, containerNum)
+    containerNum += 1
+    val nodeId = NodeId.newInstance(host, 1000)
+    Container.newInstance(containerId, nodeId, "", containerResource, RM_REQUEST_PRIORITY, null)
+  }
+
+  test("single container allocated") {
+    // request a single container and receive it
+    val handler = createAllocator()
+    handler.addResourceRequests(1)
+    handler.getNumExecutorsRunning should be (0)
+    handler.getNumPendingAllocate should be (1)
+
+    val container = createContainer("host1")
+    handler.handleAllocatedContainers(Array(container))
+
+    handler.getNumExecutorsRunning should be (1)
+    handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1")
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId)
+    rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size should be (0)
+  }
+
+  test("some containers allocated") {
+    // request a few containers and receive some of them
+    val handler = createAllocator()
+    handler.addResourceRequests(4)
+    handler.getNumExecutorsRunning should be (0)
+    handler.getNumPendingAllocate should be (4)
+
+    val container1 = createContainer("host1")
+    val container2 = createContainer("host1")
+    val container3 = createContainer("host2")
+    handler.handleAllocatedContainers(Array(container1, container2, container3))
+
+    handler.getNumExecutorsRunning should be (3)
+    handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1")
+    handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host1")
+    handler.allocatedContainerToHostMap.get(container3.getId).get should be ("host2")
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId)
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container2.getId)
+    handler.allocatedHostToContainersMap.get("host2").get should contain (container3.getId)
+  }
+
+  test("receive more containers than requested") {
+    val handler = createAllocator(2)
+    handler.addResourceRequests(2)
+    handler.getNumExecutorsRunning should be (0)
+    handler.getNumPendingAllocate should be (2)
+
+    val container1 = createContainer("host1")
+    val container2 = createContainer("host2")
+    val container3 = createContainer("host4")
+    handler.handleAllocatedContainers(Array(container1, container2, container3))
+
+    handler.getNumExecutorsRunning should be (2)
+    handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1")
+    handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host2")
+    handler.allocatedContainerToHostMap.contains(container3.getId) should be (false)
+    handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId)
+    handler.allocatedHostToContainersMap.get("host2").get should contain (container2.getId)
+    handler.allocatedHostToContainersMap.contains("host4") should be (false)
+  }
 
-class YarnAllocatorSuite extends FunSuite {
   test("memory exceeded diagnostic regexes") {
     val diagnostics =
       "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " +
-      "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " +
-      "5.8 GB of 4.2 GB virtual memory used. Killing container."
+        "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " +
+        "5.8 GB of 4.2 GB virtual memory used. Killing container."
     val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN)
     val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN)
    assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used."))
    assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used."))
   }
+
 }
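For reference, a standalone sketch of what the MockResolver wiring above gives the suite; the MockResolverDemo object is made up for illustration, assumes it compiles alongside MockResolver, and is not part of this patch. Once NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY points at MockResolver, RackResolver reports "/rack2" for host3 and "/rack1" for every other host, which is the rack layout rack-aware allocation tests can rely on.

    // Standalone sketch, not part of this patch: shows the rack mapping the suite's MockResolver
    // establishes for RackResolver-based lookups (run in a fresh JVM, since RackResolver caches
    // its mapping after first use).
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.CommonConfigurationKeysPublic
    import org.apache.hadoop.net.DNSToSwitchMapping
    import org.apache.hadoop.yarn.util.RackResolver

    object MockResolverDemo {
      def main(args: Array[String]): Unit = {
        val conf = new Configuration()
        conf.setClass(
          CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
          classOf[MockResolver], classOf[DNSToSwitchMapping])

        // MockResolver maps "host3" to "/rack2" and everything else to "/rack1".
        println(RackResolver.resolve(conf, "host3").getNetworkLocation)  // prints /rack2
        println(RackResolver.resolve(conf, "host1").getNetworkLocation)  // prints /rack1
      }
    }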