From c93a57f0d6dc32b127aa68dbe4092ab0b22a9667 Mon Sep 17 00:00:00 2001 From: Jacek Lewandowski Date: Tue, 20 Jan 2015 12:38:01 -0800 Subject: [PATCH 01/27] SPARK-4660: Use correct class loader in JavaSerializer (copy of PR #3840... ... by Piotr Kolaczkowski) Author: Jacek Lewandowski Closes #4113 from jacek-lewandowski/SPARK-4660-master and squashes the following commits: a5e84ca [Jacek Lewandowski] SPARK-4660: Use correct class loader in JavaSerializer (copy of PR #3840 by Piotr Kolaczkowski) --- .../main/scala/org/apache/spark/serializer/JavaSerializer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 662a7b91248aa..fa8a337ad63a8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -92,7 +92,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade } override def deserializeStream(s: InputStream): DeserializationStream = { - new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader) + new JavaDeserializationStream(s, defaultClassLoader) } def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { From 769aced9e7f058f5008ce405f7c9714c3db203be Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 20 Jan 2015 12:40:55 -0800 Subject: [PATCH 02/27] [SPARK-5329][WebUI] UIWorkloadGenerator should stop SparkContext. UIWorkloadGenerator don't stop SparkContext. I ran UIWorkloadGenerator and try to watch the result at WebUI but Jobs are marked as finished. It's because SparkContext is not stopped. Author: Kousuke Saruta Closes #4112 from sarutak/SPARK-5329 and squashes the following commits: bcc0fa9 [Kousuke Saruta] Disabled scalastyle for a bock comment 86a3b95 [Kousuke Saruta] Fixed UIWorkloadGenerator to stop SparkContext in it --- .../org/apache/spark/ui/UIWorkloadGenerator.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index b4677447c8872..fc1844600f1cb 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -22,20 +22,23 @@ import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.scheduler.SchedulingMode +// scalastyle:off /** * Continuously generates jobs that expose various features of the WebUI (internal testing tool). 
* - * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] + * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] [#job set (4 jobs per set)] */ +// scalastyle:on private[spark] object UIWorkloadGenerator { val NUM_PARTITIONS = 100 val INTER_JOB_WAIT_MS = 5000 def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 3) { println( - "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") + "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + + "[master] [FIFO|FAIR] [#job set (4 jobs per set)]") System.exit(1) } @@ -45,6 +48,7 @@ private[spark] object UIWorkloadGenerator { if (schedulingMode == SchedulingMode.FAIR) { conf.set("spark.scheduler.mode", "FAIR") } + val nJobSet = args(2).toInt val sc = new SparkContext(conf) def setProperties(s: String) = { @@ -84,7 +88,7 @@ private[spark] object UIWorkloadGenerator { ("Job with delays", baseData.map(x => Thread.sleep(100)).count) ) - while (true) { + (1 to nJobSet).foreach { _ => for ((desc, job) <- jobs) { new Thread { override def run() { @@ -101,5 +105,6 @@ private[spark] object UIWorkloadGenerator { Thread.sleep(INTER_JOB_WAIT_MS) } } + sc.stop() } } From 23e25543beaa5966b5f07365f338ce338fd6d71f Mon Sep 17 00:00:00 2001 From: Travis Galoppo Date: Tue, 20 Jan 2015 12:58:11 -0800 Subject: [PATCH 03/27] SPARK-5019 [MLlib] - GaussianMixtureModel exposes instances of MultivariateGauss... This PR modifies GaussianMixtureModel to expose instances of MutlivariateGaussian rather than separate mean and covariance arrays. Author: Travis Galoppo Closes #4088 from tgaloppo/spark-5019 and squashes the following commits: 3ef6c7f [Travis Galoppo] In GaussianMixtureModel: Changed name of weight, gaussian to weights, gaussians. Other sources modified accordingly. 
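A minimal sketch of the caller-side API after this change, assuming a live SparkContext `sc` and illustrative two-dimensional sample data: each mixture component is now a single MultivariateGaussian carrying its own mean and covariance, mirroring the DenseGmmEM example updated below.

import org.apache.spark.mllib.clustering.GaussianMixtureEM
import org.apache.spark.mllib.linalg.Vectors

// Illustrative data: two well-separated clusters (assumes `sc` already exists).
val data = sc.parallelize(Seq(
  Vectors.dense(-5.0, -4.0), Vectors.dense(-5.2, -3.8),
  Vectors.dense(5.0, 4.0), Vectors.dense(5.1, 4.2)))

val model = new GaussianMixtureEM().setK(2).run(data)

// weights(i) pairs with gaussians(i); mu and sigma live on the component itself.
for (i <- 0 until model.k) {
  println("weight=%f\nmu=%s\nsigma=\n%s\n".format(
    model.weights(i), model.gaussians(i).mu, model.gaussians(i).sigma))
}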
091e8da [Travis Galoppo] SPARK-5019 - GaussianMixtureModel exposes instances of MultivariateGaussian rather than mean/covariance matrices --- .../spark/examples/mllib/DenseGmmEM.scala | 2 +- .../mllib/clustering/GaussianMixtureEM.scala | 9 ++----- .../clustering/GaussianMixtureModel.scala | 21 +++++++--------- .../GMMExpectationMaximizationSuite.scala | 25 +++++++++++-------- 4 files changed, 26 insertions(+), 31 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala index 948c350953e27..de58be38c7bfb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala @@ -54,7 +54,7 @@ object DenseGmmEM { for (i <- 0 until clusters.k) { println("weight=%f\nmu=%s\nsigma=\n%s\n" format - (clusters.weight(i), clusters.mu(i), clusters.sigma(i))) + (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } println("Cluster labels (first <= 100):") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala index d8e134619411b..899fe5e9e9cf2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -134,9 +134,7 @@ class GaussianMixtureEM private ( // diagonal covariance matrices using component variances // derived from the samples val (weights, gaussians) = initialModel match { - case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => - new MultivariateGaussian(mu, sigma) - }) + case Some(gmm) => (gmm.weights, gmm.gaussians) case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) @@ -176,10 +174,7 @@ class GaussianMixtureEM private ( iter += 1 } - // Need to convert the breeze matrices to MLlib matrices - val means = Array.tabulate(k) { i => gaussians(i).mu } - val sigmas = Array.tabulate(k) { i => gaussians(i).sigma } - new GaussianMixtureModel(weights, means, sigmas) + new GaussianMixtureModel(weights, gaussians) } /** Average of dense breeze vectors */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 416cad080c408..1a2178ee7f711 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils @@ -36,12 +36,13 @@ import org.apache.spark.mllib.util.MLUtils * covariance matrix for Gaussian i */ class GaussianMixtureModel( - val weight: Array[Double], - val mu: Array[Vector], - val sigma: Array[Matrix]) extends Serializable { + val weights: Array[Double], + val gaussians: Array[MultivariateGaussian]) extends Serializable { + + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") /** Number of gaussians in mixture */ - def k: Int = 
weight.length + def k: Int = weights.length /** Maps given points to their cluster indices. */ def predict(points: RDD[Vector]): RDD[Int] = { @@ -55,14 +56,10 @@ class GaussianMixtureModel( */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext - val dists = sc.broadcast { - (0 until k).map { i => - new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) - }.toArray - } - val weights = sc.broadcast(weight) + val bcDists = sc.broadcast(gaussians) + val bcWeights = sc.broadcast(weights) points.map { x => - computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k) + computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala index 9da5495741a80..198997b5bb2b2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Vectors, Matrices} +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -39,9 +40,9 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex val seeds = Array(314589, 29032897, 50181, 494821, 4660) seeds.foreach { seed => val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data) - assert(gmm.weight(0) ~== Ew absTol 1E-5) - assert(gmm.mu(0) ~== Emu absTol 1E-5) - assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + assert(gmm.weights(0) ~== Ew absTol 1E-5) + assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) + assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) } } @@ -57,8 +58,10 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( Array(0.5, 0.5), - Array(Vectors.dense(-1.0), Vectors.dense(1.0)), - Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0))) + Array( + new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))), + new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) + ) ) val Ew = Array(1.0 / 3.0, 2.0 / 3.0) @@ -70,11 +73,11 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex .setInitialModel(initialGmm) .run(data) - assert(gmm.weight(0) ~== Ew(0) absTol 1E-3) - assert(gmm.weight(1) ~== Ew(1) absTol 1E-3) - assert(gmm.mu(0) ~== Emu(0) absTol 1E-3) - assert(gmm.mu(1) ~== Emu(1) absTol 1E-3) - assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3) - assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3) + assert(gmm.weights(0) ~== Ew(0) absTol 1E-3) + assert(gmm.weights(1) ~== Ew(1) absTol 1E-3) + assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3) + assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3) + assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) + assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } } From bc20a52b34e826895d0dcc1d783c021ebd456ebd Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 20 Jan 2015 13:26:36 -0800 Subject: [PATCH 04/27] [SPARK-5287][SQL] Add defaultSizeOf to every data type. 
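A minimal sketch of how the per-type defaults introduced by SPARK-5287 compose for nested types, assuming the org.apache.spark.sql.types package as modified below; the expected values mirror the DataTypeSuite additions further down.

import org.apache.spark.sql.types._

// Leaf estimates are fixed (IntegerType -> 4, DoubleType -> 8, StringType -> 4096, ...);
// ArrayType and MapType assume 100 elements, StructType sums its fields.
val arrayOfDouble = ArrayType(DoubleType, true)
val mapIntToString = MapType(IntegerType, StringType, true)
val struct = StructType(Seq(
  StructField("a", IntegerType, nullable = true),
  StructField("b", ArrayType(DoubleType), nullable = false),
  StructField("c", DoubleType, nullable = false)))

assert(arrayOfDouble.defaultSize == 100 * 8)           // 800
assert(mapIntToString.defaultSize == 100 * (4 + 4096)) // 410000
assert(struct.defaultSize == 4 + 800 + 8)              // 812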
JIRA: https://issues.apache.org/jira/browse/SPARK-5287 This PR only add `defaultSizeOf` to data types and make those internal type classes `protected[sql]`. I will use another PR to cleanup the type hierarchy of data types. Author: Yin Huai Closes #4081 from yhuai/SPARK-5287 and squashes the following commits: 90cec75 [Yin Huai] Update unit test. e1c600c [Yin Huai] Make internal classes protected[sql]. 7eaba68 [Yin Huai] Add `defaultSize` method to data types. fd425e0 [Yin Huai] Add all native types to NativeType.defaultSizeOf. --- .../catalyst/expressions/WrapDynamic.scala | 8 +- .../plans/logical/basicOperators.scala | 15 +-- .../apache/spark/sql/types/dataTypes.scala | 120 +++++++++++++++--- .../spark/sql/types/DataTypeSuite.scala | 40 +++++- .../spark/sql/execution/PlannerSuite.scala | 66 ++++++++-- 5 files changed, 201 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index 8328278544a1e..e2f5c7332d9ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType /** * The data type representing [[DynamicRow]] values. */ -case object DynamicType extends DataType +case object DynamicType extends DataType { + + /** + * The default size of a value of the DynamicType is 4096 bytes. + */ + override def defaultSize: Int = 4096 +} /** * Wrap a [[Row]] as a [[DynamicRow]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 1483beacc9088..9628e93274a11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -238,16 +238,11 @@ case class Rollup( case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output - override lazy val statistics: Statistics = - if (output.forall(_.dataType.isInstanceOf[NativeType])) { - val limit = limitExpr.eval(null).asInstanceOf[Int] - val sizeInBytes = (limit: Long) * output.map { a => - NativeType.defaultSizeOf(a.dataType.asInstanceOf[NativeType]) - }.sum - Statistics(sizeInBytes = sizeInBytes) - } else { - Statistics(sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product) - } + override lazy val statistics: Statistics = { + val limit = limitExpr.eval(null).asInstanceOf[Int] + val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum + Statistics(sizeInBytes = sizeInBytes) + } } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index bcd74603d4013..9f30f40a173e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -215,6 +215,9 @@ abstract class DataType { case _ => false } + /** The default size of a value of this data type. 
*/ + def defaultSize: Int + def isPrimitive: Boolean = false def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase @@ -235,33 +238,25 @@ abstract class DataType { * @group dataType */ @DeveloperApi -case object NullType extends DataType +case object NullType extends DataType { + override def defaultSize: Int = 1 +} -object NativeType { +protected[sql] object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) - - val defaultSizeOf: Map[NativeType, Int] = Map( - IntegerType -> 4, - BooleanType -> 1, - LongType -> 8, - DoubleType -> 8, - FloatType -> 4, - ShortType -> 2, - ByteType -> 1, - StringType -> 4096) } -trait PrimitiveType extends DataType { +protected[sql] trait PrimitiveType extends DataType { override def isPrimitive = true } -object PrimitiveType { +protected[sql] object PrimitiveType { private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap @@ -276,7 +271,7 @@ object PrimitiveType { } } -abstract class NativeType extends DataType { +protected[sql] abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] private[sql] val ordering: Ordering[JvmType] @@ -300,6 +295,11 @@ case object StringType extends NativeType with PrimitiveType { private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the StringType is 4096 bytes. + */ + override def defaultSize: Int = 4096 } @@ -324,6 +324,11 @@ case object BinaryType extends NativeType with PrimitiveType { x.length - y.length } } + + /** + * The default size of a value of the BinaryType is 4096 bytes. + */ + override def defaultSize: Int = 4096 } @@ -339,6 +344,11 @@ case object BooleanType extends NativeType with PrimitiveType { private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the BooleanType is 1 byte. + */ + override def defaultSize: Int = 1 } @@ -359,6 +369,11 @@ case object TimestampType extends NativeType { private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } + + /** + * The default size of a value of the TimestampType is 8 bytes. + */ + override def defaultSize: Int = 8 } @@ -379,10 +394,15 @@ case object DateType extends NativeType { private[sql] val ordering = new Ordering[JvmType] { def compare(x: Date, y: Date) = x.compareTo(y) } + + /** + * The default size of a value of the DateType is 8 bytes. + */ + override def defaultSize: Int = 8 } -abstract class NumericType extends NativeType with PrimitiveType { +protected[sql] abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). 
This gets @@ -392,13 +412,13 @@ abstract class NumericType extends NativeType with PrimitiveType { } -object NumericType { +protected[sql] object NumericType { def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] } /** Matcher for any expressions that evaluate to [[IntegralType]]s */ -object IntegralType { +protected[sql] object IntegralType { def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType.isInstanceOf[IntegralType] => true case _ => false @@ -406,7 +426,7 @@ object IntegralType { } -sealed abstract class IntegralType extends NumericType { +protected[sql] sealed abstract class IntegralType extends NumericType { private[sql] val integral: Integral[JvmType] } @@ -425,6 +445,11 @@ case object LongType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the LongType is 8 bytes. + */ + override def defaultSize: Int = 8 } @@ -442,6 +467,11 @@ case object IntegerType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the IntegerType is 4 bytes. + */ + override def defaultSize: Int = 4 } @@ -459,6 +489,11 @@ case object ShortType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the ShortType is 2 bytes. + */ + override def defaultSize: Int = 2 } @@ -476,11 +511,16 @@ case object ByteType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] private[sql] val ordering = implicitly[Ordering[JvmType]] + + /** + * The default size of a value of the ByteType is 1 byte. + */ + override def defaultSize: Int = 1 } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ -object FractionalType { +protected[sql] object FractionalType { def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType.isInstanceOf[FractionalType] => true case _ => false @@ -488,7 +528,7 @@ object FractionalType { } -sealed abstract class FractionalType extends NumericType { +protected[sql] sealed abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType] private[sql] val asIntegral: Integral[JvmType] } @@ -530,6 +570,11 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" case None => "DecimalType()" } + + /** + * The default size of a value of the DecimalType is 4096 bytes. + */ + override def defaultSize: Int = 4096 } @@ -580,6 +625,11 @@ case object DoubleType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = DoubleAsIfIntegral + + /** + * The default size of a value of the DoubleType is 8 bytes. 
+ */ + override def defaultSize: Int = 8 } @@ -598,6 +648,11 @@ case object FloatType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = FloatAsIfIntegral + + /** + * The default size of a value of the FloatType is 4 bytes. + */ + override def defaultSize: Int = 4 } @@ -636,6 +691,12 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT ("type" -> typeName) ~ ("elementType" -> elementType.jsonValue) ~ ("containsNull" -> containsNull) + + /** + * The default size of a value of the ArrayType is 100 * the default size of the element type. + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * elementType.defaultSize } @@ -805,6 +866,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def length: Int = fields.length override def iterator: Iterator[StructField] = fields.iterator + + /** + * The default size of a value of the StructType is the total default sizes of all field types. + */ + override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum } @@ -848,6 +914,13 @@ case class MapType( ("keyType" -> keyType.jsonValue) ~ ("valueType" -> valueType.jsonValue) ~ ("valueContainsNull" -> valueContainsNull) + + /** + * The default size of a value of the MapType is + * 100 * (the default size of the key type + the default size of the value type). + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) } @@ -896,4 +969,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { * Class object for the UserType */ def userClass: java.lang.Class[UserType] + + /** + * The default size of a value of the UserDefinedType is 4096 bytes. 
+ */ + override def defaultSize: Int = 4096 } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 892195f46ea24..c147be9f6b1ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -62,6 +62,7 @@ class DataTypeSuite extends FunSuite { } } + checkDataTypeJsonRepr(NullType) checkDataTypeJsonRepr(BooleanType) checkDataTypeJsonRepr(ByteType) checkDataTypeJsonRepr(ShortType) @@ -69,7 +70,9 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(LongType) checkDataTypeJsonRepr(FloatType) checkDataTypeJsonRepr(DoubleType) + checkDataTypeJsonRepr(DecimalType(10, 5)) checkDataTypeJsonRepr(DecimalType.Unlimited) + checkDataTypeJsonRepr(DateType) checkDataTypeJsonRepr(TimestampType) checkDataTypeJsonRepr(StringType) checkDataTypeJsonRepr(BinaryType) @@ -77,12 +80,39 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(ArrayType(StringType, false)) checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + val metadata = new MetadataBuilder() .putString("name", "age") .build() - checkDataTypeJsonRepr( - StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", ArrayType(DoubleType), nullable = false), - StructField("c", DoubleType, nullable = false, metadata)))) + val structType = StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", ArrayType(DoubleType), nullable = false), + StructField("c", DoubleType, nullable = false, metadata))) + checkDataTypeJsonRepr(structType) + + def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = { + test(s"Check the default size of ${dataType}") { + assert(dataType.defaultSize === expectedDefaultSize) + } + } + + checkDefaultSize(NullType, 1) + checkDefaultSize(BooleanType, 1) + checkDefaultSize(ByteType, 1) + checkDefaultSize(ShortType, 2) + checkDefaultSize(IntegerType, 4) + checkDefaultSize(LongType, 8) + checkDefaultSize(FloatType, 4) + checkDefaultSize(DoubleType, 8) + checkDefaultSize(DecimalType(10, 5), 4096) + checkDefaultSize(DecimalType.Unlimited, 4096) + checkDefaultSize(DateType, 8) + checkDefaultSize(TimestampType, 8) + checkDefaultSize(StringType, 4096) + checkDefaultSize(BinaryType, 4096) + checkDefaultSize(ArrayType(DoubleType, true), 800) + checkDefaultSize(ArrayType(StringType, false), 409600) + checkDefaultSize(MapType(IntegerType, StringType, true), 410000) + checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400) + checkDefaultSize(structType, 812) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c5b6fce5fd297..67007b8c093ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ +import org.apache.spark.sql.types._ class PlannerSuite extends FunSuite { test("unions are collapsed") { @@ -60,19 +61,62 @@ class PlannerSuite extends FunSuite { } 
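The planner test that follows leans on simple arithmetic: with a LIMIT of one row, Limit.statistics now reports the limit times the sum of the column defaults, so the fourteen simple-typed columns plus the IntegerType join key come to 16433 bytes, and a spark.sql.autoBroadcastJoinThreshold of 16434 is just high enough to pick a broadcast hash join. A minimal sketch of that estimate using only the types package (the schema value is illustrative):

import org.apache.spark.sql.types._

// Same column types as the "simpleTypes" schema in the test below, plus the join key.
val simpleTypes = Seq(
  NullType, BooleanType, ByteType, ShortType, IntegerType, LongType,
  FloatType, DoubleType, DecimalType(10, 5), DecimalType.Unlimited,
  DateType, TimestampType, StringType, BinaryType)

val fields = simpleTypes.zipWithIndex.map {
  case (dt, i) => StructField(s"c$i", dt, nullable = true)
} :+ StructField("key", IntegerType, nullable = true)

// Mirrors the new Limit.statistics: limit * sum of per-column default sizes.
val limit = 1L
val estimatedSizeInBytes = limit * fields.map(_.dataType.defaultSize).sum
// estimatedSizeInBytes == 16433, just under the 16434-byte threshold used in the test.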
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { - val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) - - // Using a threshold that is definitely larger than the small testing table (b) below - val a = testData.as('a) - val b = testData.limit(3).as('b) - val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString) + val fields = fieldTypes.zipWithIndex.map { + case (dataType, index) => StructField(s"c${index}", dataType, true) + } :+ StructField("key", IntegerType, true) + val schema = StructType(fields) + val row = Row.fromSeq(Seq.fill(fields.size)(null)) + val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) + applySchema(rowRDD, schema).registerTempTable("testLimit") + + val planned = sql( + """ + |SELECT l.a, l.b + |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) + """.stripMargin).queryExecution.executedPlan + + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + + dropTempTable("testLimit") + } - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + val origThreshold = conf.autoBroadcastJoinThreshold - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + val simpleTypes = + NullType :: + BooleanType :: + ByteType :: + ShortType :: + IntegerType :: + LongType :: + FloatType :: + DoubleType :: + DecimalType(10, 5) :: + DecimalType.Unlimited :: + DateType :: + TimestampType :: + StringType :: + BinaryType :: Nil + + checkPlan(simpleTypes, newThreshold = 16434) + + val complexTypes = + ArrayType(DoubleType, true) :: + ArrayType(StringType, false) :: + MapType(IntegerType, StringType, true) :: + MapType(IntegerType, ArrayType(DoubleType), false) :: + StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", ArrayType(DoubleType), nullable = false), + StructField("c", DoubleType, nullable = false))) :: Nil + + checkPlan(complexTypes, newThreshold = 901617) setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) } From d181c2a1fc40746947b97799b12e7dd8c213fa9c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 20 Jan 2015 15:16:14 -0800 Subject: [PATCH 05/27] [SPARK-5323][SQL] Remove Row's Seq inheritance. Author: Reynold Xin Closes #4115 from rxin/row-seq and squashes the following commits: e33abd8 [Reynold Xin] Fixed compilation error. cceb650 [Reynold Xin] Python test fixes, and removal of WrapDynamic. 0334a52 [Reynold Xin] mkString. 9cdeb7d [Reynold Xin] Hive tests. 15681c2 [Reynold Xin] Fix more test cases. ea9023a [Reynold Xin] Fixed a catalyst test. c5e2cb5 [Reynold Xin] Minor patch up. b9cab7c [Reynold Xin] [SPARK-5323][SQL] Remove Row's Seq inheritance. 
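A minimal sketch of what removing the Seq inheritance means at call sites (illustrative values; only methods added or kept by this patch are used): collection-style operations now go through an explicit toSeq, while length, equality, hashing, mkString, and pattern matching are provided by Row itself.

import org.apache.spark.sql.Row

val row = Row(1, "a", null)

// Row is no longer a Seq[Any]; ask for the Seq view explicitly.
val values: Seq[Any] = row.toSeq

// Structural equality, hashCode, and mkString are now defined directly on Row.
assert(row == Row(1, "a", null))
assert(row.mkString("[", ",", "]") == "[1,a,null]")
assert(row.length == 3 && !row.isNullAt(0) && row.isNullAt(2))

// Extraction still works through Row.unapplySeq.
row match {
  case Row(i: Int, s: String, _) => println(s"$i -> $s")
  case _ => ()
}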
--- .../main/scala/org/apache/spark/sql/Row.scala | 75 +++- .../spark/sql/catalyst/ScalaReflection.scala | 3 +- .../spark/sql/catalyst/dsl/package.scala | 3 - .../spark/sql/catalyst/expressions/Cast.scala | 3 +- .../sql/catalyst/expressions/Projection.scala | 310 +++++++------ .../expressions/SpecificMutableRow.scala | 4 +- .../catalyst/expressions/WrapDynamic.scala | 64 --- .../codegen/GenerateProjection.scala | 22 +- .../spark/sql/catalyst/expressions/rows.scala | 6 +- .../sql/catalyst/ScalaReflectionSuite.scala | 2 +- .../org/apache/spark/sql/SchemaRDD.scala | 19 - .../columnar/InMemoryColumnarTableScan.scala | 9 +- .../compression/compressionSchemes.scala | 2 +- .../spark/sql/execution/debug/package.scala | 2 +- .../spark/sql/execution/pythonUdfs.scala | 7 +- .../org/apache/spark/sql/json/JsonRDD.scala | 8 +- .../spark/sql/parquet/ParquetConverter.scala | 2 +- .../org/apache/spark/sql/DslQuerySuite.scala | 146 +++--- .../org/apache/spark/sql/JoinSuite.scala | 242 +++++----- .../org/apache/spark/sql/QueryTest.scala | 31 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 416 +++++++++--------- .../sql/ScalaReflectionRelationSuite.scala | 6 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 6 +- .../columnar/InMemoryColumnarQuerySuite.scala | 18 +- .../columnar/PartitionBatchPruningSuite.scala | 2 +- .../compression/BooleanBitSetSuite.scala | 2 +- .../apache/spark/sql/execution/TgfSuite.scala | 4 +- .../org/apache/spark/sql/json/JsonSuite.scala | 185 ++++---- .../sql/parquet/ParquetFilterSuite.scala | 91 ++-- .../spark/sql/parquet/ParquetIOSuite.scala | 12 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 28 +- .../sql/parquet/ParquetQuerySuite2.scala | 2 +- .../spark/sql/sources/TableScanSuite.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 8 +- .../spark/sql/hive/HiveInspectors.scala | 20 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 2 +- .../spark/sql/hive/hiveWriterContainers.scala | 2 +- .../org/apache/spark/sql/QueryTest.scala | 48 +- .../spark/sql/hive/HiveInspectorSuite.scala | 8 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 6 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 12 +- .../spark/sql/hive/StatisticsSuite.scala | 8 +- .../sql/hive/execution/HiveQuerySuite.scala | 34 +- .../sql/hive/execution/HiveUdfSuite.scala | 12 +- .../sql/hive/execution/SQLQuerySuite.scala | 18 +- .../spark/sql/parquet/HiveParquetSuite.scala | 2 +- .../spark/sql/parquet/parquetSuites.scala | 58 +-- 47 files changed, 1018 insertions(+), 956 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 208ec92987ac8..41bb4f012f2e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.util.hashing.MurmurHash3 + import org.apache.spark.sql.catalyst.expressions.GenericRow @@ -32,7 +34,7 @@ object Row { * } * }}} */ - def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + def unapplySeq(row: Row): Some[Seq[Any]] = Some(row.toSeq) /** * This method can be used to construct a [[Row]] with the given values. @@ -43,6 +45,16 @@ object Row { * This method can be used to construct a [[Row]] from a [[Seq]] of values. 
*/ def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) + + def fromTuple(tuple: Product): Row = fromSeq(tuple.productIterator.toSeq) + + /** + * Merge multiple rows into a single row, one after another. + */ + def merge(rows: Row*): Row = { + // TODO: Improve the performance of this if used in performance critical part. + new GenericRow(rows.flatMap(_.toSeq).toArray) + } } @@ -103,7 +115,13 @@ object Row { * * @group row */ -trait Row extends Seq[Any] with Serializable { +trait Row extends Serializable { + /** Number of elements in the Row. */ + def size: Int = length + + /** Number of elements in the Row. */ + def length: Int + /** * Returns the value at position i. If the value is null, null is returned. The following * is a mapping between Spark SQL types and return types: @@ -291,12 +309,61 @@ trait Row extends Seq[Any] with Serializable { /** Returns true if there are any NULL values in this row. */ def anyNull: Boolean = { - val l = length + val len = length var i = 0 - while (i < l) { + while (i < len) { if (isNullAt(i)) { return true } i += 1 } false } + + override def equals(that: Any): Boolean = that match { + case null => false + case that: Row => + if (this.length != that.length) { + return false + } + var i = 0 + val len = this.length + while (i < len) { + if (apply(i) != that.apply(i)) { + return false + } + i += 1 + } + true + case _ => false + } + + override def hashCode: Int = { + // Using Scala's Seq hash code implementation. + var n = 0 + var h = MurmurHash3.seqSeed + val len = length + while (n < len) { + h = MurmurHash3.mix(h, apply(n).##) + n += 1 + } + MurmurHash3.finalizeHash(h, n) + } + + /* ---------------------- utility methods for Scala ---------------------- */ + + /** + * Return a Scala Seq representing the row. ELements are placed in the same order in the Seq. + */ + def toSeq: Seq[Any] + + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. + */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d280db83b26f7..191d16fb10b5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -84,8 +84,9 @@ trait ScalaReflection { } def convertRowToScala(r: Row, schema: StructType): Row = { + // TODO: This is very slow!!! 
new GenericRow( - r.zip(schema.fields.map(_.dataType)) + r.toSeq.zip(schema.fields.map(_.dataType)) .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 26c855878d202..417659eed5957 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -272,9 +272,6 @@ package object dsl { def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) - def sfilter(dynamicUdf: (DynamicRow) => Boolean) = - Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan) - def sample( fraction: Double, withReplacement: Boolean = true, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1a2133bbbcec7..ece5ee73618cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -407,7 +407,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val casts = from.fields.zip(to.fields).map { case (fromField, toField) => cast(fromField.dataType, toField.dataType) } - buildCast[Row](_, row => Row(row.zip(casts).map { + // TODO: This is very slow! + buildCast[Row](_, row => Row(row.toSeq.zip(casts).map { case (v, cast) => if (v == null) null else cast(v) }: _*)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index e7e81a21fdf03..db5d897ee569f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -105,45 +105,45 @@ class JoinedRow extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override 
def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -154,8 +154,16 @@ class JoinedRow extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. + if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -197,45 +205,45 @@ class JoinedRow2 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < 
row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -246,8 +254,16 @@ class JoinedRow2 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. + if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -283,45 +299,45 @@ class JoinedRow3 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) 
- def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -332,8 +348,16 @@ class JoinedRow3 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. + if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -369,45 +393,45 @@ class JoinedRow4 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) 
row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -418,8 +442,16 @@ class JoinedRow4 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. + if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -455,45 +487,45 @@ class JoinedRow5 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - 
row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -504,7 +536,15 @@ class JoinedRow5 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. + if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 37d9f0ed5c79e..7434165f654f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -209,6 +209,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def length: Int = values.length + override def toSeq: Seq[Any] = values.map(_.boxed).toSeq + override def setNullAt(i: Int): Unit = { values(i).isNull = true } @@ -231,8 +233,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR if (value == null) setNullAt(ordinal) else values(ordinal).update(value) } - override def iterator: Iterator[Any] = values.map(_.boxed).iterator - override def setString(ordinal: Int, value: String) = update(ordinal, value) override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala deleted file mode 100644 index e2f5c7332d9ab..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.language.dynamics - -import org.apache.spark.sql.types.DataType - -/** - * The data type representing [[DynamicRow]] values. - */ -case object DynamicType extends DataType { - - /** - * The default size of a value of the DynamicType is 4096 bytes. - */ - override def defaultSize: Int = 4096 -} - -/** - * Wrap a [[Row]] as a [[DynamicRow]]. 
- */ -case class WrapDynamic(children: Seq[Attribute]) extends Expression { - type EvaluatedType = DynamicRow - - def nullable = false - - def dataType = DynamicType - - override def eval(input: Row): DynamicRow = input match { - // Avoid copy for generic rows. - case g: GenericRow => new DynamicRow(children, g.values) - case otherRowType => new DynamicRow(children, otherRowType.toArray) - } -} - -/** - * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language. - * Since the type of the column is not known at compile time, all attributes are converted to - * strings before being passed to the function. - */ -class DynamicRow(val schema: Seq[Attribute], values: Array[Any]) - extends GenericRow(values) with Dynamic { - - def selectDynamic(attributeName: String): String = { - val ordinal = schema.indexWhere(_.name == attributeName) - values(ordinal).toString - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index cc97cb4f50b69..69397a73a8880 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -77,14 +77,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { """.children : Seq[Tree] } - val iteratorFunction = { - val allColumns = (0 until expressions.size).map { i => - val iLit = ru.Literal(Constant(i)) - q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" - } - q"override def iterator = Iterator[Any](..$allColumns)" - } - val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" val applyFunction = { val cases = (0 until expressions.size).map { i => @@ -191,20 +183,26 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } """ + val allColumns = (0 until expressions.size).map { i => + val iLit = ru.Literal(Constant(i)) + q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + } + val copyFunction = - q""" - override def copy() = new $genericRowType(this.toArray) - """ + q"override def copy() = new $genericRowType(Array[Any](..$allColumns))" + + val toSeqFunction = + q"override def toSeq: Seq[Any] = Seq(..$allColumns)" val classBody = nullFunctions ++ ( lengthDef +: - iteratorFunction +: applyFunction +: updateFunction +: equalsFunction +: hashCodeFunction +: copyFunction +: + toSeqFunction +: (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) val code = q""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index c22b8426841da..8df150e2f855f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -44,7 +44,7 @@ trait MutableRow extends Row { */ object EmptyRow extends Row { override def apply(i: Int): Any = throw new UnsupportedOperationException - override def iterator = Iterator.empty + override def toSeq = Seq.empty override def length = 0 override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException override def getInt(i: Int): Int = throw new UnsupportedOperationException @@ -70,7 +70,7 @@ class GenericRow(protected[sql] val 
values: Array[Any]) extends Row { def this(size: Int) = this(new Array[Any](size)) - override def iterator = values.iterator + override def toSeq = values.toSeq override def length = values.length @@ -119,7 +119,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } // Custom hashCode function that matches the efficient code generated version. - override def hashCode(): Int = { + override def hashCode: Int = { var result: Int = 37 var i = 0 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 6df5db4c80f34..5138942a55daa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -244,7 +244,7 @@ class ScalaReflectionSuite extends FunSuite { test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) + val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) val dataType = schemaFor[PrimitiveData].dataType assert(convertToCatalyst(data, dataType) === convertedData) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index ae4d8ba90c5bd..d1e21dffeb8c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -330,25 +330,6 @@ class SchemaRDD( sqlContext, Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) - /** - * :: Experimental :: - * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use - * scala's Dynamic trait to emulate an ORM of in a dynamically typed language. Since the type of - * the column is not known at compile time, all attributes are converted to strings before - * being passed to the function. - * - * {{{ - * schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith") - * }}} - * - * @group Query - */ - @Experimental - def where(dynamicUdf: (DynamicRow) => Boolean) = - new SchemaRDD( - sqlContext, - Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)) - /** * :: Experimental :: * Returns a sampled version of the underlying dataset. 
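The hunks in this patch share one theme: Row stops behaving like a Seq (iterator, zip and ++ disappear) and callers reach element values through an explicit toSeq, with Row.merge and Row.fromSeq covering the concatenation and construction cases. A minimal sketch of the caller-side migration, assuming the 1.3-era org.apache.spark.sql.Row and org.apache.spark.sql.types APIs referenced elsewhere in this patch; the object name and the exampleFields value are illustrative only, not part of the change:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}

object RowToSeqSketch {
  def main(args: Array[String]): Unit = {
    val row = Row(1, "a")

    // Before this change a Row could be zipped directly; afterwards callers go
    // through toSeq, mirroring the row.toSeq.zip(fields) rewrites in
    // debug/package.scala and pythonUdfs.scala.
    val exampleFields: Seq[DataType] = Seq(IntegerType, StringType)
    val paired = row.toSeq.zip(exampleFields)

    // Concatenating rows now goes through Row.merge instead of (row1 ++ row2),
    // as in the InMemoryRelation statistics hunk.
    val merged = Row.merge(Row(1, "a"), Row(2, "b"))

    // A plain Seq[Any] can still be lifted back into a Row when needed.
    val lifted = Row.fromSeq(paired.map(_._1))

    println(s"paired=$paired merged=$merged lifted=$lifted")
  }
}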
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 065fae3c83df1..11d5943fb427f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -21,7 +21,6 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -128,8 +127,7 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - val stats = Row.fromSeq( - columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) + val stats = Row.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -271,9 +269,10 @@ private[sql] case class InMemoryColumnarTableScan( // Extract rows via column accessors new Iterator[Row] { + private[this] val rowLen = nextRow.length override def next() = { var i = 0 - while (i < nextRow.length) { + while (i < rowLen) { columnAccessors(i).extractTo(nextRow, i) i += 1 } @@ -297,7 +296,7 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { def statsString = relation.partitionStatistics.schema - .zip(cachedBatch.stats) + .zip(cachedBatch.stats.toSeq) .map { case (a, s) => s"${a.name}: $s" } .mkString(", ") logInfo(s"Skipping partition based on stats $statsString") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 64673248394c6..68a5b1de7691b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -127,7 +127,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { while (from.hasRemaining) { columnType.extract(from, value, 0) - if (value.head == currentValue.head) { + if (value(0) == currentValue(0)) { currentRun += 1 } else { // Writes current run diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 46245cd5a1869..4d7e338e8ed13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -144,7 +144,7 @@ package object debug { case (null, _) => case (row: Row, StructType(fields)) => - row.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) } + row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (s: Seq[_], ArrayType(elemType, _)) => s.foreach(typeCheck(_, elemType)) case (m: Map[_, _], MapType(keyType, valueType, _)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 7ed64aad10d4e..b85021acc9d4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala 
@@ -116,9 +116,9 @@ object EvaluatePython { def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (row: Seq[Any], struct: StructType) => + case (row: Row, struct: StructType) => val fields = struct.fields.map(field => field.dataType) - row.zip(fields).map { + row.toSeq.zip(fields).map { case (obj, dataType) => toJava(obj, dataType) }.toArray @@ -143,7 +143,8 @@ object EvaluatePython { * Convert Row into Java Array (for pickled into Python) */ def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { - row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray + // TODO: this is slow! + row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray } // Converts value to the type specified by the data type. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index db70a7eac72b9..9171939f7e8f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -458,16 +458,16 @@ private[sql] object JsonRDD extends Logging { gen.writeEndArray() case (MapType(kv,vv, _), v: Map[_,_]) => - gen.writeStartObject + gen.writeStartObject() v.foreach { p => gen.writeFieldName(p._1.toString) valWriter(vv,p._2) } - gen.writeEndObject + gen.writeEndObject() - case (StructType(ty), v: Seq[_]) => + case (StructType(ty), v: Row) => gen.writeStartObject() - ty.zip(v).foreach { + ty.zip(v.toSeq).foreach { case (_, null) => case (field, v) => gen.writeFieldName(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index b4aed04199129..9d9150246c8d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -66,7 +66,7 @@ private[sql] object CatalystConverter { // TODO: consider using Array[T] for arrays to avoid boxing of primitive types type ArrayScalaType[T] = Seq[T] - type StructScalaType[T] = Seq[T] + type StructScalaType[T] = Row type MapScalaType[K, V] = Map[K, V] protected[parquet] def createConverter( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 2bcfe28456997..afbfe214f1ce4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -45,28 +45,28 @@ class DslQuerySuite extends QueryTest { test("agg") { checkAnswer( testData2.groupBy('a)('a, sum('b)), - Seq((1,3),(2,3),(3,3)) + Seq(Row(1,3), Row(2,3), Row(3,3)) ) checkAnswer( testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), - 9 + Row(9) ) checkAnswer( testData2.aggregate(sum('b)), - 9 + Row(9) ) } test("convert $\"attribute name\" into unresolved attribute") { checkAnswer( testData.where($"key" === 1).select($"value"), - Seq(Seq("1"))) + Row("1")) } test("convert Scala Symbol 'attrname into unresolved attribute") { checkAnswer( testData.where('key === 1).select('value), - Seq(Seq("1"))) + Row("1")) } test("select *") { @@ -78,61 +78,61 @@ class DslQuerySuite extends QueryTest { test("simple select") { checkAnswer( testData.where('key === 1).select('value), - Seq(Seq("1"))) + Row("1")) } test("select with functions") { checkAnswer( testData.select(sum('value), avg('value), 
count(1)), - Seq(Seq(5050.0, 50.5, 100))) + Row(5050.0, 50.5, 100)) checkAnswer( testData2.select('a + 'b, 'a < 'b), Seq( - Seq(2, false), - Seq(3, true), - Seq(3, false), - Seq(4, false), - Seq(4, false), - Seq(5, false))) + Row(2, false), + Row(3, true), + Row(3, false), + Row(4, false), + Row(4, false), + Row(5, false))) checkAnswer( testData2.select(sumDistinct('a)), - Seq(Seq(6))) + Row(6)) } test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), - Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) checkAnswer( testData2.orderBy('a.asc, 'b.desc), - Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) checkAnswer( testData2.orderBy('a.desc, 'b.desc), - Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) checkAnswer( testData2.orderBy('a.desc, 'b.asc), - Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) checkAnswer( arrayData.orderBy('data.getItem(0).asc), - arrayData.collect().sortBy(_.data(0)).toSeq) + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(0).desc), - arrayData.collect().sortBy(_.data(0)).reverse.toSeq) + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( - mapData.orderBy('data.getItem(1).asc), - mapData.collect().sortBy(_.data(1)).toSeq) + arrayData.orderBy('data.getItem(1).asc), + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( - mapData.orderBy('data.getItem(1).desc), - mapData.collect().sortBy(_.data(1)).reverse.toSeq) + arrayData.orderBy('data.getItem(1).desc), + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("partition wide sorting") { @@ -147,19 +147,19 @@ class DslQuerySuite extends QueryTest { // (3, 2) checkAnswer( testData2.sortBy('a.asc, 'b.asc), - Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) checkAnswer( testData2.sortBy('a.asc, 'b.desc), - Seq((1,2), (1,1), (2,1), (2,2), (3,2), (3,1))) + Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1))) checkAnswer( testData2.sortBy('a.desc, 'b.desc), - Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2))) + Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2))) checkAnswer( testData2.sortBy('a.desc, 'b.asc), - Seq((2,1), (1,1), (1,2), (3,1), (3,2), (2,2))) + Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2))) } test("limit") { @@ -169,11 +169,11 @@ class DslQuerySuite extends QueryTest { checkAnswer( arrayData.limit(1), - arrayData.take(1).toSeq) + arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) checkAnswer( mapData.limit(1), - mapData.take(1).toSeq) + mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } test("SPARK-3395 limit distinct") { @@ -184,8 +184,8 @@ class DslQuerySuite extends QueryTest { .registerTempTable("onerow") checkAnswer( sql("select * from onerow inner join testData2 on onerow.a = testData2.a"), - (1, 1, 1, 1) :: - (1, 1, 1, 2) :: Nil) + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: Nil) } test("SPARK-3858 generator qualifiers are discarded") { @@ -193,55 +193,55 @@ class DslQuerySuite extends QueryTest { arrayData.as('ad) .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) .select("ex.data".attr), - Seq(1, 2, 3, 2, 3, 
4).map(Seq(_))) + Seq(1, 2, 3, 2, 3, 4).map(Row(_))) } test("average") { checkAnswer( testData2.aggregate(avg('a)), - 2.0) + Row(2.0)) checkAnswer( testData2.aggregate(avg('a), sumDistinct('a)), // non-partial - (2.0, 6.0) :: Nil) + Row(2.0, 6.0) :: Nil) checkAnswer( decimalData.aggregate(avg('a)), - new java.math.BigDecimal(2.0)) + Row(new java.math.BigDecimal(2.0))) checkAnswer( decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial - (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) checkAnswer( decimalData.aggregate(avg('a cast DecimalType(10, 2))), - new java.math.BigDecimal(2.0)) + Row(new java.math.BigDecimal(2.0))) checkAnswer( decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial - (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } test("null average") { checkAnswer( testData3.aggregate(avg('b)), - 2.0) + Row(2.0)) checkAnswer( testData3.aggregate(avg('b), countDistinct('b)), - (2.0, 1) :: Nil) + Row(2.0, 1)) checkAnswer( testData3.aggregate(avg('b), sumDistinct('b)), // non-partial - (2.0, 2.0) :: Nil) + Row(2.0, 2.0)) } test("zero average") { checkAnswer( emptyTableData.aggregate(avg('a)), - null) + Row(null)) checkAnswer( emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial - (null, null) :: Nil) + Row(null, null)) } test("count") { @@ -249,28 +249,28 @@ class DslQuerySuite extends QueryTest { checkAnswer( testData2.aggregate(count('a), sumDistinct('a)), // non-partial - (6, 6.0) :: Nil) + Row(6, 6.0)) } test("null count") { checkAnswer( testData3.groupBy('a)('a, count('b)), - Seq((1,0), (2, 1)) + Seq(Row(1,0), Row(2, 1)) ) checkAnswer( testData3.groupBy('a)('a, count('a + 'b)), - Seq((1,0), (2, 1)) + Seq(Row(1,0), Row(2, 1)) ) checkAnswer( testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), - (2, 1, 2, 2, 1) :: Nil + Row(2, 1, 2, 2, 1) ) checkAnswer( testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial - (1, 1, 2) :: Nil + Row(1, 1, 2) ) } @@ -279,28 +279,28 @@ class DslQuerySuite extends QueryTest { checkAnswer( emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial - (0, null) :: Nil) + Row(0, null)) } test("zero sum") { checkAnswer( emptyTableData.aggregate(sum('a)), - null) + Row(null)) } test("zero sum distinct") { checkAnswer( emptyTableData.aggregate(sumDistinct('a)), - null) + Row(null)) } test("except") { checkAnswer( lowerCaseData.except(upperCaseData), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.except(lowerCaseData), Nil) checkAnswer(upperCaseData.except(upperCaseData), Nil) } @@ -308,10 +308,10 @@ class DslQuerySuite extends QueryTest { test("intersect") { checkAnswer( lowerCaseData.intersect(lowerCaseData), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) } @@ -321,75 +321,75 @@ class DslQuerySuite extends QueryTest { checkAnswer( // SELECT *, foo(key, value) FROM testData testData.select(Star(None), foo.call('key, 'value)).limit(3), - (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil + Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } 
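  // The expected answers in this suite move from tuples and bare Seq values to explicit
  // Row objects, matching the checkAnswer(rdd, expectedAnswer: Seq[Row]) signature that the
  // QueryTest.scala hunk later in this patch introduces. A hedged before/after sketch,
  // assuming testData2, groupBy, sum and checkAnswer from this suite are in scope:
  //
  //   checkAnswer(testData2.groupBy('a)('a, sum('b)), Seq((1, 3), (2, 3), (3, 3)))          // old form
  //   checkAnswer(testData2.groupBy('a)('a, sum('b)), Seq(Row(1, 3), Row(2, 3), Row(3, 3))) // new form
  //   checkAnswer(testData2.aggregate(sum('b)), Row(9))                                     // single-Row overload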
test("sqrt") { checkAnswer( testData.select(sqrt('key)).orderBy('key asc), - (1 to 100).map(n => Seq(math.sqrt(n))) + (1 to 100).map(n => Row(math.sqrt(n))) ) checkAnswer( testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc), - (1 to 100).map(n => Seq(math.sqrt(n), n)) + (1 to 100).map(n => Row(math.sqrt(n), n)) ) checkAnswer( testData.select(sqrt(Literal(null))), - (1 to 100).map(_ => Seq(null)) + (1 to 100).map(_ => Row(null)) ) } test("abs") { checkAnswer( testData.select(abs('key)).orderBy('key asc), - (1 to 100).map(n => Seq(n)) + (1 to 100).map(n => Row(n)) ) checkAnswer( negativeData.select(abs('key)).orderBy('key desc), - (1 to 100).map(n => Seq(n)) + (1 to 100).map(n => Row(n)) ) checkAnswer( testData.select(abs(Literal(null))), - (1 to 100).map(_ => Seq(null)) + (1 to 100).map(_ => Row(null)) ) } test("upper") { checkAnswer( lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Seq(c.toString.toUpperCase())) + ('a' to 'd').map(c => Row(c.toString.toUpperCase())) ) checkAnswer( testData.select(upper('value), 'key), - (1 to 100).map(n => Seq(n.toString, n)) + (1 to 100).map(n => Row(n.toString, n)) ) checkAnswer( testData.select(upper(Literal(null))), - (1 to 100).map(n => Seq(null)) + (1 to 100).map(n => Row(null)) ) } test("lower") { checkAnswer( upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Seq(c.toString.toLowerCase())) + ('A' to 'F').map(c => Row(c.toString.toLowerCase())) ) checkAnswer( testData.select(lower('value), 'key), - (1 to 100).map(n => Seq(n.toString, n)) + (1 to 100).map(n => Row(n.toString, n)) ) checkAnswer( testData.select(lower(Literal(null))), - (1 to 100).map(n => Seq(null)) + (1 to 100).map(n => Row(null)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e5ab16f9dd661..cd36da7751e83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -117,10 +117,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( upperCaseData.join(lowerCaseData, Inner).where('n === 'N), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d") )) } @@ -128,10 +128,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d") )) } @@ -140,10 +140,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.where('a === 1).as('y) checkAnswer( x.join(y).where("x.a".attr === "y.a".attr), - (1,1,1,1) :: - (1,1,1,2) :: - (1,2,1,1) :: - (1,2,1,2) :: Nil + Row(1,1,1,1) :: + Row(1,1,1,2) :: + Row(1,2,1,1) :: + Row(1,2,1,2) :: Nil ) } @@ -163,54 +163,54 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), testData.flatMap( - row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } test("cartisian product join") { checkAnswer( testData3.join(testData3), - (1, null, 1, null) :: - (1, null, 2, 2) :: - (2, 2, 1, null) :: - (2, 2, 2, 2) :: Nil) + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 
2) :: Nil) } test("left outer join") { checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), - (1, "A", null, null) :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), - (1, "A", null, null) :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. @@ -221,12 +221,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY l.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) checkAnswer( sql( @@ -235,42 +235,42 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY r.a """.stripMargin), - (null, 6) :: Nil) + Row(null, 6) :: Nil) } test("right outer join") { checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), - (null, null, 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), - (null, null, 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( 
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. @@ -281,7 +281,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY l.a """.stripMargin), - (null, 6) :: Nil) + Row(null, 6)) checkAnswer( sql( @@ -290,12 +290,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) } test("full outer join") { @@ -307,32 +307,32 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", null, null) :: - (null, null, 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", null, null) :: - (null, null, 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. 
checkAnswer( @@ -342,7 +342,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY l.a """.stripMargin), - (null, 10) :: Nil) + Row(null, 10)) checkAnswer( sql( @@ -351,13 +351,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: - (null, 4) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) checkAnswer( sql( @@ -366,13 +366,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY l.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: - (null, 4) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) checkAnswer( sql( @@ -381,7 +381,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY r.a """.stripMargin), - (null, 10) :: Nil) + Row(null, 10)) } test("broadcasted left semi join operator selection") { @@ -412,12 +412,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("left semi join") { val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(rdd, - (1, 1) :: - (1, 2) :: - (2, 1) :: - (2, 2) :: - (3, 1) :: - (3, 2) :: Nil) + Row(1, 1) :: + Row(1, 2) :: + Row(2, 1) :: + Row(2, 2) :: + Row(3, 1) :: + Row(3, 2) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 68ddecc7f610d..42a21c148df53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -47,26 +47,17 @@ class QueryTest extends PlanTest { * @param rdd the [[SchemaRDD]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = { - val convertedAnswer = expectedAnswer match { - case s: Seq[_] if s.isEmpty => s - case s: Seq[_] if s.head.isInstanceOf[Product] && - !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq) - case s: Seq[_] => s - case singleItem => Seq(Seq(singleItem)) - } - + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - def prepareAnswer(answer: Seq[Any]): Seq[Any] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). 
- val converted = answer.map { - case s: Seq[_] => s.map { + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) case o => o - } - case o => o + }) } if (!isSorted) converted.sortBy(_.toString) else converted } @@ -82,7 +73,7 @@ class QueryTest extends PlanTest { """.stripMargin) } - if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { fail(s""" |Results do not match for query: |${rdd.logicalPlan} @@ -92,15 +83,19 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), s"== Spark Answer - ${sparkAnswer.size} ==" +: prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } - def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = { + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 54fabc5c915fb..03b44ca1d6695 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -46,7 +46,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), - Seq(1, 1, 2 ,2 ,3 ,3).map(Seq(_)) + Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_)) ) } @@ -70,13 +70,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3176 Added Parser of SQL ABS()") { checkAnswer( sql("SELECT ABS(-1.3)"), - 1.3) + Row(1.3)) checkAnswer( sql("SELECT ABS(0.0)"), - 0.0) + Row(0.0)) checkAnswer( sql("SELECT ABS(2.5)"), - 2.5) + Row(2.5)) } test("aggregation with codegen") { @@ -89,13 +89,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), - 4) + Row(4)) } test("SPARK-2041 column name equals tablename") { checkAnswer( sql("SELECT tableName FROM tableName"), - "test") + Row("test")) } test("SQRT") { @@ -115,40 +115,40 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-2407 Added Parser of SQL SUBSTR()") { checkAnswer( sql("SELECT substr(tableName, 1, 2) FROM tableName"), - "te") + Row("te")) checkAnswer( sql("SELECT substr(tableName, 3) FROM tableName"), - "st") + Row("st")) checkAnswer( sql("SELECT substring(tableName, 1, 2) FROM tableName"), - "te") + Row("te")) checkAnswer( sql("SELECT substring(tableName, 3) FROM tableName"), - "st") + Row("st")) } test("SPARK-3173 Timestamp support in the parser") { checkAnswer(sql( "SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"), - 
Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) checkAnswer(sql( "SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) checkAnswer(sql( """SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003' AND time>'1970-01-01 00:00:00.001'"""), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), - Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + Seq(Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='123'"), @@ -158,13 +158,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("index into array") { checkAnswer( sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), - arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq) + arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect()) } test("left semi greater than predicate") { checkAnswer( sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq((3,1), (3,2)) + Seq(Row(3,1), Row(3,2)) ) } @@ -173,7 +173,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql( "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"), arrayData.map(d => - (d.nestedData, + Row(d.nestedData, d.nestedData(0)(0), d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq) } @@ -181,13 +181,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq((1,3),(2,3),(3,3))) + Seq(Row(1,3), Row(2,3), Row(3,3))) } test("aggregates with nulls") { checkAnswer( sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - (1, 3, 2, 6, 3) :: Nil + Row(1, 3, 2, 6, 3) ) } @@ -200,29 +200,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("simple select") { checkAnswer( sql("SELECT value FROM testData WHERE key = 1"), - Seq(Seq("1"))) + Row("1")) } def sortTest() = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), - Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), - Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), - Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), - Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a ASC"), - (1 to 5).map(Row(_)).toSeq) + (1 to 5).map(Row(_))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a DESC"), @@ -230,19 +230,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( 
sql("SELECT * FROM arrayData ORDER BY data[0] ASC"), - arrayData.collect().sortBy(_.data(0)).toSeq) + arrayData.collect().sortBy(_.data(0)).map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM arrayData ORDER BY data[0] DESC"), - arrayData.collect().sortBy(_.data(0)).reverse.toSeq) + arrayData.collect().sortBy(_.data(0)).reverse.map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM mapData ORDER BY data[1] ASC"), - mapData.collect().sortBy(_.data(1)).toSeq) + mapData.collect().sortBy(_.data(1)).map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM mapData ORDER BY data[1] DESC"), - mapData.collect().sortBy(_.data(1)).reverse.toSeq) + mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } test("sorting") { @@ -266,94 +266,94 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM arrayData LIMIT 1"), - arrayData.collect().take(1).toSeq) + arrayData.collect().take(1).map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM mapData LIMIT 1"), - mapData.collect().take(1).toSeq) + mapData.collect().take(1).map(Row.fromTuple).toSeq) } test("from follow multiple brackets") { checkAnswer(sql( "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), - 1 + Row(1) ) checkAnswer(sql( "select key from (select * from testData) x limit 1"), - 1 + Row(1) ) checkAnswer(sql( "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"), - 1 + Row(1) ) } test("average") { checkAnswer( sql("SELECT AVG(a) FROM testData2"), - 2.0) + Row(2.0)) } test("average overflow") { checkAnswer( sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), - Seq((2147483645.0,1),(2.0,2))) + Seq(Row(2147483645.0,1), Row(2.0,2))) } test("count") { checkAnswer( sql("SELECT COUNT(*) FROM testData2"), - testData2.count()) + Row(testData2.count())) } test("count distinct") { checkAnswer( sql("SELECT COUNT(DISTINCT b) FROM testData2"), - 2) + Row(2)) } test("approximate count distinct") { checkAnswer( sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"), - 3) + Row(3)) } test("approximate count distinct with user provided standard deviation") { checkAnswer( sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"), - 3) + Row(3)) } test("null count") { checkAnswer( sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), - Seq((1, 0), (2, 1))) + Seq(Row(1, 0), Row(2, 1))) checkAnswer( sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), - (2, 1, 2, 2, 1) :: Nil) + Row(2, 1, 2, 2, 1)) } test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d"))) + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d"))) } test("inner join ON, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d"))) + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d"))) } test("inner join, where, multiple matches") { @@ -363,10 +363,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y |WHERE x.a = y.a""".stripMargin), - (1,1,1,1) :: - (1,1,1,2) :: - (1,2,1,1) :: - (1,2,1,2) :: Nil) + Row(1,1,1,1) :: + 
Row(1,1,1,2) :: + Row(1,2,1,1) :: + Row(1,2,1,2) :: Nil) } test("inner join, no matches") { @@ -397,38 +397,38 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | SELECT * FROM testData) y |WHERE x.key = y.key""".stripMargin), testData.flatMap( - row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } ignore("cartesian product join") { checkAnswer( testData3.join(testData3), - (1, null, 1, null) :: - (1, null, 2, 2) :: - (2, 2, 1, null) :: - (2, 2, 2, 2) :: Nil) + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 2) :: Nil) } test("left outer join") { checkAnswer( sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) } test("right outer join") { checkAnswer( sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("full outer join") { @@ -440,12 +440,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM upperCaseData WHERE N >= 3) rightTable | ON leftTable.N = rightTable.N """.stripMargin), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row (4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("SPARK-3349 partitioning after limit") { @@ -457,12 +457,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .registerTempTable("subset2") checkAnswer( sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), - (3, "c", 3) :: - (4, "d", 4) :: Nil) + Row(3, "c", 3) :: + Row(4, "d", 4) :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), - (1, "a", 1) :: - (2, "b", 2) :: Nil) + Row(1, "a", 1) :: + Row(2, "b", 2) :: Nil) } test("mixed-case keywords") { @@ -474,28 +474,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (sElEcT * FROM upperCaseData whERe N >= 3) rightTable | oN leftTable.N = rightTable.N """.stripMargin), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("select with table name as qualifier") { checkAnswer( sql("SELECT testData.value FROM testData WHERE testData.key = 1"), - Seq(Seq("1"))) + Row("1")) } test("inner join ON with table name as qualifier") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d"))) + Row(1, "A", 1, "a"), + Row(2, 
"B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d"))) } test("qualified select with inner join ON with table name as qualifier") { @@ -503,72 +503,72 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " + "ON lowerCaseData.n = upperCaseData.N"), Seq( - (1, "A"), - (2, "B"), - (3, "C"), - (4, "D"))) + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"))) } test("system function upper()") { checkAnswer( sql("SELECT n,UPPER(l) FROM lowerCaseData"), Seq( - (1, "A"), - (2, "B"), - (3, "C"), - (4, "D"))) + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"))) checkAnswer( sql("SELECT n, UPPER(s) FROM nullStrings"), Seq( - (1, "ABC"), - (2, "ABC"), - (3, null))) + Row(1, "ABC"), + Row(2, "ABC"), + Row(3, null))) } test("system function lower()") { checkAnswer( sql("SELECT N,LOWER(L) FROM upperCaseData"), Seq( - (1, "a"), - (2, "b"), - (3, "c"), - (4, "d"), - (5, "e"), - (6, "f"))) + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(5, "e"), + Row(6, "f"))) checkAnswer( sql("SELECT n, LOWER(s) FROM nullStrings"), Seq( - (1, "abc"), - (2, "abc"), - (3, null))) + Row(1, "abc"), + Row(2, "abc"), + Row(3, null))) } test("UNION") { checkAnswer( sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), - (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: - (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: + Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), - (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil) + Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), - (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") :: - (4, "d") :: (4, "d") :: Nil) + Row(1, "a") :: Row(1, "a") :: Row(2, "b") :: Row(2, "b") :: Row(3, "c") :: Row(3, "c") :: + Row(4, "d") :: Row(4, "d") :: Nil) } test("UNION with column mismatches") { // Column name mismatches are allowed. checkAnswer( sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), - (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: - (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: + Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) // Column type mismatches are not allowed, forcing a type coercion. checkAnswer( sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), - ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_))) + ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) // Column type mismatches where a coercion is not possible, in this case between integer // and array types, trigger a TreeNodeException. 
intercept[TreeNodeException[_]] { @@ -579,10 +579,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("EXCEPT") { checkAnswer( sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) checkAnswer( @@ -592,10 +592,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("INTERSECT") { checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil) } @@ -613,25 +613,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(s"$testKey=$testVal")) + Row(s"$testKey=$testVal") ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(s"$testKey=$testVal"), - Seq(s"${testKey + testKey}=${testVal + testVal}")) + Row(s"$testKey=$testVal"), + Row(s"${testKey + testKey}=${testVal + testVal}")) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(s"$testKey=$testVal")) + Row(s"$testKey=$testVal") ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(s"$nonexistentKey=")) + Row(s"$nonexistentKey=") ) conf.clear() } @@ -655,17 +655,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { schemaRDD1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), - (1, "A1", true, null) :: - (2, "B2", false, null) :: - (3, "C3", true, null) :: - (4, "D4", true, 2147483644) :: Nil) + Row(1, "A1", true, null) :: + Row(2, "B2", false, null) :: + Row(3, "C3", true, null) :: + Row(4, "D4", true, 2147483644) :: Nil) checkAnswer( sql("SELECT f1, f4 FROM applySchema1"), - (1, null) :: - (2, null) :: - (3, null) :: - (4, 2147483644) :: Nil) + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) val schema2 = StructType( StructField("f1", StructType( @@ -685,17 +685,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { schemaRDD2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), - (Seq(1, true), Map("A1" -> null)) :: - (Seq(2, false), Map("B2" -> null)) :: - (Seq(3, true), Map("C3" -> null)) :: - (Seq(4, true), Map("D4" -> 2147483644)) :: Nil) + Row(Row(1, true), Map("A1" -> null)) :: + Row(Row(2, false), Map("B2" -> null)) :: + Row(Row(3, true), Map("C3" -> null)) :: + Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil) checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), - (1, null) :: - (2, null) :: - (3, null) :: - (4, 2147483644) :: Nil) + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) // The value of a MapType column can be a mutable map. 
val rowRDD3 = unparsedStrings.map { r => @@ -711,26 +711,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), - (1, null) :: - (2, null) :: - (3, null) :: - (4, 2147483644) :: Nil) + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) } test("SPARK-3423 BETWEEN") { checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"), - Seq((5, "5"), (6, "6"), (7, "7")) + Seq(Row(5, "5"), Row(6, "6"), Row(7, "7")) ) checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"), - Seq((7, "7")) + Row(7, "7") ) checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), - Seq() + Nil ) } @@ -738,7 +738,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // TODO Ensure true/false string letter casing is consistent with Hive in all cases. checkAnswer( sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), - ("true", "false") :: Nil) + Row("true", "false")) } test("metadata is propagated correctly") { @@ -768,17 +768,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3371 Renaming a function expression with group by gives error") { udf.register("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), + Row(1)) } test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { checkAnswer( - sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1) + sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), + Row(1)) } test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { checkAnswer( - sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), + Row(1)) } test("throw errors for non-aggregate attributes with aggregation") { @@ -808,130 +811,131 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Test to check we can use Long.MinValue") { checkAnswer( - sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Long.MinValue + sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue) ) checkAnswer( - sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), (1 to 100).map(Row(_)).toSeq + sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), + (1 to 100).map(Row(_)).toSeq ) } test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), 0.3 + sql("SELECT 0.3"), Row(0.3) ) checkAnswer( - sql("SELECT -0.8"), -0.8 + sql("SELECT -0.8"), Row(-0.8) ) checkAnswer( - sql("SELECT .5"), 0.5 + sql("SELECT .5"), Row(0.5) ) checkAnswer( - sql("SELECT -.18"), -0.18 + sql("SELECT -.18"), Row(-0.18) ) } test("Auto cast integer type") { checkAnswer( - sql(s"SELECT ${Int.MaxValue + 1L}"), Int.MaxValue + 1L + sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L) ) checkAnswer( - sql(s"SELECT ${Int.MinValue - 1L}"), Int.MinValue - 1L + sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L) ) checkAnswer( - sql("SELECT 9223372036854775808"), new java.math.BigDecimal("9223372036854775808") + sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) ) checkAnswer( - sql("SELECT -9223372036854775809"), new 
java.math.BigDecimal("-9223372036854775809") + sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) ) } test("Test to check we can apply sign to expression") { checkAnswer( - sql("SELECT -100"), -100 + sql("SELECT -100"), Row(-100) ) checkAnswer( - sql("SELECT +230"), 230 + sql("SELECT +230"), Row(230) ) checkAnswer( - sql("SELECT -5.2"), -5.2 + sql("SELECT -5.2"), Row(-5.2) ) checkAnswer( - sql("SELECT +6.8"), 6.8 + sql("SELECT +6.8"), Row(6.8) ) checkAnswer( - sql("SELECT -key FROM testData WHERE key = 2"), -2 + sql("SELECT -key FROM testData WHERE key = 2"), Row(-2) ) checkAnswer( - sql("SELECT +key FROM testData WHERE key = 3"), 3 + sql("SELECT +key FROM testData WHERE key = 3"), Row(3) ) checkAnswer( - sql("SELECT -(key + 1) FROM testData WHERE key = 1"), -2 + sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2) ) checkAnswer( - sql("SELECT - key + 1 FROM testData WHERE key = 10"), -9 + sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9) ) checkAnswer( - sql("SELECT +(key + 5) FROM testData WHERE key = 5"), 10 + sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10) ) checkAnswer( - sql("SELECT -MAX(key) FROM testData"), -100 + sql("SELECT -MAX(key) FROM testData"), Row(-100) ) checkAnswer( - sql("SELECT +MAX(key) FROM testData"), 100 + sql("SELECT +MAX(key) FROM testData"), Row(100) ) checkAnswer( - sql("SELECT - (-10)"), 10 + sql("SELECT - (-10)"), Row(10) ) checkAnswer( - sql("SELECT + (-key) FROM testData WHERE key = 32"), -32 + sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32) ) checkAnswer( - sql("SELECT - (+Max(key)) FROM testData"), -100 + sql("SELECT - (+Max(key)) FROM testData"), Row(-100) ) checkAnswer( - sql("SELECT - - 3"), 3 + sql("SELECT - - 3"), Row(3) ) checkAnswer( - sql("SELECT - + 20"), -20 + sql("SELECT - + 20"), Row(-20) ) checkAnswer( - sql("SELEcT - + 45"), -45 + sql("SELEcT - + 45"), Row(-45) ) checkAnswer( - sql("SELECT + + 100"), 100 + sql("SELECT + + 100"), Row(100) ) checkAnswer( - sql("SELECT - - Max(key) FROM testData"), 100 + sql("SELECT - - Max(key) FROM testData"), Row(100) ) checkAnswer( - sql("SELECT + - key FROM testData WHERE key = 33"), -33 + sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33) ) } @@ -943,7 +947,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { |JOIN testData b ON a.key = b.key |JOIN testData c ON a.key = c.key """.stripMargin), - (1 to 100).map(i => Seq(i, i, i))) + (1 to 100).map(i => Row(i, i, i))) } test("SPARK-3483 Special chars in column names") { @@ -953,19 +957,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3814 Support Bitwise & operator") { - checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), 1) + checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise | operator") { - checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), 1) + checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ^ operator") { - checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), 1) + checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ~ operator") { - checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), -2) + checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), Row(-2)) } test("SPARK-4120 Join of multiple tables does not work in SparkSQL") { @@ -975,40 +979,40 @@ class SQLQuerySuite extends QueryTest with 
BeforeAndAfterAll { |FROM testData a,testData b,testData c |where a.key = b.key and a.key = c.key """.stripMargin), - (1 to 100).map(i => Seq(i, i, i))) + (1 to 100).map(i => Row(i, i, i))) } test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), - (11 to 100).map(i => Seq(i))) + (11 to 100).map(i => Row(i))) } test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), - (1 to 99).map(i => Seq(i))) + (1 to 99).map(i => Row(i))) } test("SPARK-4322 Grouping field with struct field as sub expression") { jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") - checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1) + checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) dropTempTable("data") jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") - checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2) + checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( sql("SELECT a + b FROM testData2 ORDER BY a"), - Seq(2, 3, 3 ,4 ,4 ,5).map(Seq(_)) + Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), - Seq((3, 1), (3, 2), (2, 1), (2,2), (1, 1), (1, 2)) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2)) ) } @@ -1021,13 +1025,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { rdd2.registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), - (1 to 2).map(i => Seq(i))) + (1 to 2).map(i => Row(i))) } test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.registerTempTable("distinctData") - checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), 2) + checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ee381da491054..a015884bae282 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -82,7 +82,7 @@ class ScalaReflectionRelationSuite extends FunSuite { rdd.registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === - Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, + Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) } @@ -91,7 +91,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectNullData") - assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null)) + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with 
Nones") { @@ -99,7 +99,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null)) + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 9be0b38e689ff..be2b34de077c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -42,8 +42,8 @@ class ColumnStatsSuite extends FunSuite { test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => - assert(actual === expected) + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + case (actual, expected) => assert(actual === expected) } } @@ -54,7 +54,7 @@ class ColumnStatsSuite extends FunSuite { val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType]) + val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] val stats = columnStats.collectedStatistics diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index d94729ba92360..e61f3c39631da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -49,7 +49,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key - }.toSeq) + }.map(Row.fromTuple)) } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { @@ -63,49 +63,49 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1678 regression: compression must not lose repeated values") { checkAnswer( sql("SELECT * FROM repeatedData"), - repeatedData.collect().toSeq) + repeatedData.collect().toSeq.map(Row.fromTuple)) cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), - repeatedData.collect().toSeq) + repeatedData.collect().toSeq.map(Row.fromTuple)) } test("with null values") { checkAnswer( sql("SELECT * FROM nullableRepeatedData"), - nullableRepeatedData.collect().toSeq) + nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), - nullableRepeatedData.collect().toSeq) + nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) } test("SPARK-2729 regression: timestamp data type") { checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq) + timestamps.collect().toSeq.map(Row.fromTuple)) cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq) + timestamps.collect().toSeq.map(Row.fromTuple)) } test("SPARK-3320 regression: batched column 
buffer building should work with empty partitions") { checkAnswer( sql("SELECT * FROM withEmptyParts"), - withEmptyParts.collect().toSeq) + withEmptyParts.collect().toSeq.map(Row.fromTuple)) cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), - withEmptyParts.collect().toSeq) + withEmptyParts.collect().toSeq.map(Row.fromTuple)) } test("SPARK-4182 Caching complex types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 592cafbbdc203..c3a3f8ddc3ebf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -108,7 +108,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val queryExecution = schemaRdd.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { - schemaRdd.collect().map(_.head).toArray + schemaRdd.collect().map(_(0)).toArray } val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index d9e488e0ffd16..8b518f094174c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -34,7 +34,7 @@ class BooleanBitSetSuite extends FunSuite { val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) - val values = rows.map(_.head) + val values = rows.map(_(0)) rows.foreach(builder.appendFrom(_, 0)) val buffer = builder.build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index 2cab5e0c44d92..272c0d4cb2335 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -59,7 +59,7 @@ class TgfSuite extends QueryTest { checkAnswer( inputData.generate(ExampleTGF()), Seq( - "michael is 29 years old" :: Nil, - "Next year, michael will be 30 years old" :: Nil)) + Row("michael is 29 years old"), + Row("Next year, michael will be 30 years old"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 2bc9aede32f2a..94d14acccbb18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -229,13 +229,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - (new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.") :: Nil + Row(new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") ) } @@ -271,48 +271,49 @@ class JsonSuite extends QueryTest { // Access elements of a primitive array. 
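The columnar and pruning suites above keep their fixtures as local tuples and convert them just before comparison, since rows are no longer treated as Seqs; Row.fromTuple does the conversion and Row#apply(i) replaces the old head-style access. A small sketch with made-up fixture values:

import org.apache.spark.sql.Row

// A tuple-based local fixture of the kind these suites collect or parallelize.
val fixture: Seq[(Int, String)] = Seq((1, "val_1"), (2, "val_2"))

// Convert once so checkAnswer receives Seq[Row].
val expected: Seq[Row] = fixture.map(Row.fromTuple)

// Positional access replaces the old Seq-style _.head on rows.
val firstColumn: Seq[Any] = expected.map(_(0))            // Seq(1, 2)
val asPlainSeqs: Seq[Seq[Any]] = expected.map(_.toSeq)    // plain Seqs where still needed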
checkAnswer( sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), - ("str1", "str2", null) :: Nil + Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( sql("select arrayOfNull from jsonTable"), - Seq(Seq(null, null, null, null)) :: Nil + Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), - (new java.math.BigDecimal("922337203685477580700"), - new java.math.BigDecimal("-922337203685477580800"), null) :: Nil + Row(new java.math.BigDecimal("922337203685477580700"), + new java.math.BigDecimal("-922337203685477580800"), null) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), - (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil + Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), - (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil + Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). checkAnswer( sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), - ("str2", 2.1) :: Nil + Row("str2", 2.1) ) // Access elements of an array of structs. checkAnswer( sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + "from jsonTable"), - (true :: "str1" :: null :: Nil, - false :: null :: null :: Nil, - null :: null :: null :: Nil, - null) :: Nil + Row( + Row(true, "str1", null), + Row(false, null, null), + Row(null, null, null), + null) ) // Access a struct and fields inside of it. @@ -327,13 +328,13 @@ class JsonSuite extends QueryTest { // Access an array field of a struct. checkAnswer( sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), - (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil + Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. checkAnswer( sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), - (5, null) :: Nil + Row(5, null) ) } @@ -344,14 +345,14 @@ class JsonSuite extends QueryTest { // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), - (true, "str1") :: Nil + Row(true, "str1") ) // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. // Getting all values of a specific field from an array of structs. 
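In the JSON expectations above, struct-typed columns surface as nested Rows while array-typed columns stay ordinary Scala Seqs. A short sketch of reading both out of one expected row; the shape and values are illustrative only:

import org.apache.spark.sql.Row

// One expected row: a struct column followed by an array column.
val expected = Row(Row(true, "str1", null), Seq(4, 5, 6))

// Struct fields come back as a nested Row and are read positionally...
val structField = expected(0).asInstanceOf[Row](1)       // "str1"
// ...while array columns remain plain Seqs.
val arrayElem   = expected(1).asInstanceOf[Seq[Int]](1)  // 5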
checkAnswer( sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), - (Seq(true, false), Seq("str1", null)) :: Nil + Row(Seq(true, false), Seq("str1", null)) ) } @@ -372,57 +373,57 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - ("true", 11L, null, 1.1, "13.1", "str1") :: - ("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - ("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - (null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("true", 11L, null, 1.1, "13.1", "str1") :: + Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: + Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: + Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. checkAnswer( sql("select num_bool - 10 from jsonTable where num_bool > 11"), - 2 + Row(2) ) // Widening to LongType checkAnswer( sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), - Seq(21474836370L) :: Seq(21474836470L) :: Nil + Row(21474836370L) :: Row(21474836470L) :: Nil ) checkAnswer( sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), - Seq(-89) :: Seq(21474836370L) :: Seq(21474836470L) :: Nil + Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Seq(new java.math.BigDecimal("21474836472.1")) :: Seq(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil ) // Widening to DoubleType checkAnswer( sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), - Seq(101.2) :: Seq(21474836471.2) :: Nil + Row(101.2) :: Row(21474836471.2) :: Nil ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - 92233720368547758071.2 + Row(92233720368547758071.2) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), - new java.math.BigDecimal("92233720368547758061.2").doubleValue + Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue) ) // String and Boolean conflict: resolve the type as string. checkAnswer( sql("select * from jsonTable where str_bool = 'str1'"), - ("true", 11L, null, 1.1, "13.1", "str1") :: Nil + Row("true", 11L, null, 1.1, "13.1", "str1") ) } @@ -434,24 +435,24 @@ class JsonSuite extends QueryTest { // Number and Boolean conflict: resolve the type as boolean in this query. checkAnswer( sql("select num_bool from jsonTable where NOT num_bool"), - false + Row(false) ) checkAnswer( sql("select str_bool from jsonTable where NOT str_bool"), - false + Row(false) ) // Right now, the analyzer does not know that num_bool should be treated as a boolean. // Number and Boolean conflict: resolve the type as boolean in this query. 
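Two smaller conventions show up in the widening tests above: a scalar result is written as a one-column Row rather than a bare value, and decimal expectations are built from java.math.BigDecimal, which the reworked QueryTest later in this patch normalises to scala.BigDecimal so that comparison is by numeric value rather than by scale. A sketch under those assumptions:

import org.apache.spark.sql.Row

// A scalar answer is a one-column Row, not a bare Long.
val scalarAnswer = Row(21474836370L)

// Decimal expectations use java.math.BigDecimal; the same normalisation the test
// harness applies can be written explicitly like this.
val decimalAnswer = Row(new java.math.BigDecimal("21474836472.1"))
val normalised = Row.fromSeq(decimalAnswer.toSeq.map {
  case d: java.math.BigDecimal => BigDecimal(d)
  case o => o
})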
checkAnswer( sql("select num_bool from jsonTable where num_bool"), - true + Row(true) ) checkAnswer( sql("select str_bool from jsonTable where str_bool"), - false + Row(false) ) // The plan of the following DSL is @@ -464,7 +465,7 @@ class JsonSuite extends QueryTest { jsonSchemaRDD. where('num_str > BigDecimal("92233720368547758060")). select('num_str + 1.2 as Symbol("num")), - new java.math.BigDecimal("92233720368547758061.2") + Row(new java.math.BigDecimal("92233720368547758061.2")) ) // The following test will fail. The type of num_str is StringType. @@ -475,7 +476,7 @@ class JsonSuite extends QueryTest { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 13"), - Seq(14.3) :: Seq(92233720368547758071.2) :: Nil + Row(14.3) :: Row(92233720368547758071.2) :: Nil ) } @@ -496,10 +497,10 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - (Seq(), "11", "[1,2,3]", Seq(null), "[]") :: - (null, """{"field":false}""", null, null, "{}") :: - (Seq(4, 5, 6), null, "str", Seq(null), "[7,8,9]") :: - (Seq(7), "{}","[str1,str2,33]", Seq("str"), """{"field":true}""") :: Nil + Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: + Row(null, """{"field":false}""", null, null, "{}") :: + Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: + Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil ) } @@ -518,16 +519,16 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", - """{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) :: - Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: - Seq(null, null, Seq("1", "2", "3")) :: Nil + Row(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", + """{"field":"str"}"""), Seq(Row(214748364700L), Row(1)), null) :: + Row(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: + Row(null, null, Seq("1", "2", "3")) :: Nil ) // Treat an element as a number. 
checkAnswer( sql("select array1[0] + 1 from jsonTable where array1 is not null"), - 2 + Row(2) ) } @@ -568,13 +569,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) } @@ -594,13 +595,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTableSQL"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) } @@ -626,13 +627,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable1"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) @@ -643,13 +644,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable2"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) } @@ -659,7 +660,7 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), - (true, "str1") :: Nil + Row(true, "str1") ) checkAnswer( sql( @@ -667,7 +668,7 @@ class JsonSuite extends QueryTest { |select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] |from jsonTable """.stripMargin), - ("str2", 6) :: Nil + Row("str2", 6) ) } @@ -681,7 +682,7 @@ class JsonSuite extends QueryTest { |select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0] |from jsonTable """.stripMargin), - (5, 7, 8) :: Nil + Row(5, 7, 8) ) checkAnswer( sql( @@ -690,7 +691,7 @@ class JsonSuite extends QueryTest { |arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4 |from jsonTable """.stripMargin), - ("str1", Nil, "str4", 2) :: Nil + Row("str1", Nil, "str4", 2) ) } @@ -704,10 +705,10 @@ class JsonSuite extends QueryTest { |select a, b, c |from jsonTable """.stripMargin), - ("str_a_1", null, null) :: - ("str_a_2", null, null) :: - (null, "str_b_3", null) :: - ("str_a_4", "str_b_4", "str_c_4") :: Nil + Row("str_a_1", null, null) :: + Row("str_a_2", null, null) :: + Row(null, "str_b_3", null) :: + Row("str_a_4", "str_b_4", "str_c_4") :: Nil ) } @@ -734,12 +735,12 @@ class JsonSuite extends QueryTest { |SELECT a, b, c, _unparsed |FROM jsonTable """.stripMargin), - (null, null, null, "{") :: - (null, null, null, "") :: - (null, null, null, """{"a":1, b:2}""") :: - (null, null, null, """{"a":{, b:3}""") :: - ("str_a_4", "str_b_4", "str_c_4", null) :: - (null, null, null, "]") :: Nil + Row(null, null, null, "{") :: + Row(null, null, null, "") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil ) checkAnswer( @@ -749,7 +750,7 @@ class JsonSuite extends QueryTest { |FROM jsonTable |WHERE _unparsed IS NULL """.stripMargin), - ("str_a_4", "str_b_4", "str_c_4") :: Nil + 
Row("str_a_4", "str_b_4", "str_c_4") ) checkAnswer( @@ -759,11 +760,11 @@ class JsonSuite extends QueryTest { |FROM jsonTable |WHERE _unparsed IS NOT NULL """.stripMargin), - Seq("{") :: - Seq("") :: - Seq("""{"a":1, b:2}""") :: - Seq("""{"a":{, b:3}""") :: - Seq("]") :: Nil + Row("{") :: + Row("") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil ) TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) @@ -793,10 +794,10 @@ class JsonSuite extends QueryTest { |SELECT field1, field2, field3, field4 |FROM jsonTable """.stripMargin), - Seq(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) :: - Seq(null, Seq(null, Seq(Seq(1))), null, null) :: - Seq(null, null, Seq(Seq(null), Seq(Seq("2"))), null) :: - Seq(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil + Row(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) :: + Row(null, Seq(null, Seq(Row(1))), null, null) :: + Row(null, null, Seq(Seq(null), Seq(Row("2"))), null) :: + Row(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil ) } @@ -851,12 +852,12 @@ class JsonSuite extends QueryTest { primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, - "this is a simple string.") :: Nil + "this is a simple string.") ) val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1) @@ -865,38 +866,38 @@ class JsonSuite extends QueryTest { // Access elements of a primitive array. checkAnswer( sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), - ("str1", "str2", null) :: Nil + Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( sql("select arrayOfNull from complexTable"), - Seq(Seq(null, null, null, null)) :: Nil + Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"), - (new java.math.BigDecimal("922337203685477580700"), - new java.math.BigDecimal("-922337203685477580800"), null) :: Nil + Row(new java.math.BigDecimal("922337203685477580700"), + new java.math.BigDecimal("-922337203685477580800"), null) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), - (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil + Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), - (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil + Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). checkAnswer( sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), - ("str2", 2.1) :: Nil + Row("str2", 2.1) ) // Access a struct and fields inside of it. @@ -911,13 +912,13 @@ class JsonSuite extends QueryTest { // Access an array field of a struct. checkAnswer( sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), - (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil + Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. 
checkAnswer( sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), - (5, null) :: Nil + Row(5, null) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4c3a04506ce42..4ad8c472007fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -46,7 +46,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { predicate: Predicate, filterClass: Class[_ <: FilterPredicate], checker: (SchemaRDD, Any) => Unit, - expectedResult: => Any): Unit = { + expectedResult: Any): Unit = { withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { val query = rdd.select(output.map(_.attr): _*).where(predicate) @@ -65,11 +65,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } } - private def checkFilterPushdown + private def checkFilterPushdown1 (rdd: SchemaRDD, output: Symbol*) (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: => Any): Unit = { - checkFilterPushdown(rdd, output, predicate, filterClass, checkAnswer _, expectedResult) + (expectedResult: => Seq[Row]): Unit = { + checkFilterPushdown(rdd, output, predicate, filterClass, + (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), expectedResult) + } + + private def checkFilterPushdown + (rdd: SchemaRDD, output: Symbol*) + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) + (expectedResult: Int): Unit = { + checkFilterPushdown(rdd, output, predicate, filterClass, + (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), Seq(Row(expectedResult))) } def checkBinaryFilterPushdown @@ -89,27 +98,25 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - boolean") { withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) { Seq(Row(true), Row(false)) } - checkFilterPushdown(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(true) - checkFilterPushdown(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]]) { - false - } + checkFilterPushdown1(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(Seq(Row(true))) + checkFilterPushdown1(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]])(Seq(Row(false))) } } test("filter pushdown - integer") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) { (1 to 4).map(Row.apply(_)) } checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[Integer]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) { + checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) { (2 to 4).map(Row.apply(_)) } @@ -126,7 +133,7 @@ class 
ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[Integer]])(4) checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { Seq(Row(1), Row(4)) } } @@ -134,13 +141,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - long") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) { (1 to 4).map(Row.apply(_)) } checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) { + checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) { (2 to 4).map(Row.apply(_)) } @@ -157,7 +164,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Long]])(4) checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { Seq(Row(1), Row(4)) } } @@ -165,13 +172,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - float") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) { (1 to 4).map(Row.apply(_)) } checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) { + checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) { (2 to 4).map(Row.apply(_)) } @@ -188,7 +195,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Float]])(4) checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { Seq(Row(1), Row(4)) } } @@ -196,13 +203,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - double") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) { (1 to 4).map(Row.apply(_)) } checkFilterPushdown(rdd, '_1)('_1 === 1, 
classOf[Eq[java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) { + checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) { (2 to 4).map(Row.apply(_)) } @@ -219,7 +226,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Double]])(4) checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { Seq(Row(1), Row(4)) } } @@ -227,30 +234,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - string") { withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { (1 to 4).map(i => Row.apply(i.toString)) } - checkFilterPushdown(rdd, '_1)('_1 === "1", classOf[Eq[String]])("1") - checkFilterPushdown(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) { + checkFilterPushdown1(rdd, '_1)('_1 === "1", classOf[Eq[String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) { (2 to 4).map(i => Row.apply(i.toString)) } - checkFilterPushdown(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])("4") - checkFilterPushdown(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])("4") + checkFilterPushdown1(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])(Seq(Row("4"))) + checkFilterPushdown1(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])("4") - checkFilterPushdown(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])("4") + checkFilterPushdown1(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])(Seq(Row("4"))) + checkFilterPushdown1(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown(rdd, '_1)(!('_1 < "4"), classOf[Operators.GtEq[java.lang.String]])("4") - checkFilterPushdown(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])("3") - checkFilterPushdown(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)(!('_1 < 
"4"), classOf[Operators.GtEq[java.lang.String]])(Seq(Row("4"))) + checkFilterPushdown1(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])(Seq(Row("3"))) + checkFilterPushdown1(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) { Seq(Row("1"), Row("4")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 973819aaa4d77..a57e4e85a35ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -68,8 +68,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { /** * Writes `data` to a Parquet file, reads it back and check file contents. */ - protected def checkParquetFile[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetRDD(data)(checkAnswer(_, data)) + protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { + withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } test("basic data types (without binary)") { @@ -143,7 +143,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data) { rdd => // Structs are converted to `Row`s checkAnswer(rdd, data.map { case Tuple1(struct) => - Tuple1(Row(struct.productIterator.toSeq: _*)) + Row(Row(struct.productIterator.toSeq: _*)) }) } } @@ -153,7 +153,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data) { rdd => // Structs are converted to `Row`s checkAnswer(rdd, data.map { case Tuple1(struct) => - Tuple1(Row(struct.productIterator.toSeq: _*)) + Row(Row(struct.productIterator.toSeq: _*)) }) } } @@ -162,7 +162,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) withParquetRDD(data) { rdd => checkAnswer(rdd, data.map { case Tuple1(m) => - Tuple1(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) + Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) }) } } @@ -261,7 +261,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) checkAnswer(parquetFile(path.toString), (0 until 10).map { i => - (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3a073a6b7057e..2c5345b1f9148 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -28,7 +28,7 @@ import parquet.hadoop.util.ContextUtil import parquet.io.api.Binary import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Row => _, _} import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -191,8 +191,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -207,8 
+207,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -223,8 +223,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -239,8 +239,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -255,8 +255,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -303,7 +303,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result.size === 9, "self-join result has incorrect size") assert(result(0).size === 12, "result row has incorrect size") result.zipWithIndex.foreach { - case (row, index) => row.zipWithIndex.foreach { + case (row, index) => row.toSeq.zipWithIndex.foreach { case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column") } } @@ -423,7 +423,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(rdd_saved(0) === Row(null, null, null, null, null)) Utils.deleteRecursively(file) assert(true) } @@ -438,7 +438,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(rdd_saved(0) === Row(null, null, null, null, null)) Utils.deleteRecursively(file) assert(true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala index 4c081fb4510b2..7b3f8c22af2db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala @@ -38,7 +38,7 @@ class ParquetQuerySuite2 extends QueryTest with ParquetTest { val data = (0 until 10).map(i => (i, i.toString)) withParquetTable(data, "t") { sql("INSERT INTO t SELECT * FROM t") - checkAnswer(table("t"), data ++ data) + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 264f6d94c4ed9..b1e0919b7aed1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -244,7 +244,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT count(*) FROM tableWithSchema", - 10) + Seq(Row(10))) sqlTest( "SELECT `string$%Field` FROM tableWithSchema", @@ -260,7 +260,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where int_Field=1", - Seq(Seq(1, 2))) + Seq(Row(1, 2))) sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 10833c113216a..3e26fe3675768 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -368,10 +368,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { .mkString("\t") } case command: ExecutedCommand => - command.executeCollect().map(_.head.toString) + command.executeCollect().map(_(0).toString) case other => - val result: Seq[Seq[Any]] = other.executeCollect().toSeq + val result: Seq[Seq[Any]] = other.executeCollect().map(_.toSeq).toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. @@ -395,7 +395,7 @@ private object HiveContext { protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => - struct.zip(fields).map { + struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => @@ -418,7 +418,7 @@ private object HiveContext { /** Hive outputs fields of structs slightly differently than top level attributes. */ protected def toHiveStructString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => - struct.zip(fields).map { + struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index eeabfdd857916..82dba99900df9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -348,7 +348,7 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach { (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) } struct @@ -432,7 +432,7 @@ private[hive] trait HiveInspectors { } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Seq[_]] + val row = a.asInstanceOf[Row] // 1. 
create the pojo (most likely) object val result = x.create() var i = 0 @@ -448,7 +448,7 @@ private[hive] trait HiveInspectors { result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Seq[_]] + val row = a.asInstanceOf[Row] val result = new java.util.ArrayList[AnyRef](fieldRefs.length) var i = 0 while (i < fieldRefs.length) { @@ -475,7 +475,7 @@ private[hive] trait HiveInspectors { } def wrap( - row: Seq[Any], + row: Row, inspectors: Seq[ObjectInspector], cache: Array[AnyRef]): Array[AnyRef] = { var i = 0 @@ -486,6 +486,18 @@ private[hive] trait HiveInspectors { cache } + def wrap( + row: Seq[Any], + inspectors: Seq[ObjectInspector], + cache: Array[AnyRef]): Array[AnyRef] = { + var i = 0 + while (i < inspectors.length) { + cache(i) = wrap(row(i), inspectors(i)) + i += 1 + } + cache + } + /** * @param dataType Catalyst data type * @return Hive java object inspector (recursively), not the Writable ObjectInspector diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d898b876c39f8..76d2140372197 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -360,7 +360,7 @@ private[hive] case class HiveUdafFunction( protected lazy val cached = new Array[AnyRef](exprs.length) def update(input: Row): Unit = { - val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray + val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index cc8bb3e172c6e..aae175e426ade 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -209,7 +209,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { val dynamicPartPath = dynamicPartColNames - .zip(row.takeRight(dynamicPartColNames.length)) + .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => val string = if (rawVal == null) null else String.valueOf(rawVal) s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index f89c49d292c6c..f320d732fb77a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util._ * So, we duplicate this code here. */ class QueryTest extends PlanTest { + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer @@ -56,17 +57,20 @@ class QueryTest extends PlanTest { * @param rdd the [[SchemaRDD]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. 
*/ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = { - val convertedAnswer = expectedAnswer match { - case s: Seq[_] if s.isEmpty => s - case s: Seq[_] if s.head.isInstanceOf[Product] && - !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq) - case s: Seq[_] => s - case singleItem => Seq(Seq(singleItem)) + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case o => o + }) + } + if (!isSorted) converted.sortBy(_.toString) else converted } - - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty - def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer val sparkAnswer = try rdd.collect().toSeq catch { case e: Exception => fail( @@ -74,11 +78,12 @@ class QueryTest extends PlanTest { |Exception thrown while executing query: |${rdd.queryExecution} |== Exception == - |${stackTraceToString(e)} + |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} """.stripMargin) } - if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { fail(s""" |Results do not match for query: |${rdd.logicalPlan} @@ -88,11 +93,22 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } + + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 4864607252034..2d3ff680125ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -129,6 +129,12 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { } } + def checkValues(row1: Seq[Any], row2: Row): Unit = { + row1.zip(row2.toSeq).map { + case (r1, r2) => checkValue(r1, r2) + } + } + def checkValue(v1: Any, v2: Any): Unit = { (v1, v2) match { case (r1: Decimal, r2: Decimal) => @@ -198,7 +204,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) - checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + checkValues(row, 
unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 7cfb875e05db3..0e6636d38ed3c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -43,7 +43,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq + testData.collect().toSeq.map(Row.fromTuple) ) // Add more data. @@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq ++ testData.collect().toSeq + testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq ) // Now overwrite. @@ -61,7 +61,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the registered table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq + testData.collect().toSeq.map(Row.fromTuple) ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 53d8aa7739bc2..7408c7ffd69e8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -155,7 +155,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), - ("a", "b") :: Nil) + Row("a", "b")) FileUtils.deleteDirectory(tempDir) sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -164,14 +164,14 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // will show. checkAnswer( sql("SELECT * FROM jsonTable"), - ("a1", "b1") :: Nil) + Row("a1", "b1")) refreshTable("jsonTable") // Check that the refresh worked checkAnswer( sql("SELECT * FROM jsonTable"), - ("a1", "b1", "c1") :: Nil) + Row("a1", "b1", "c1")) FileUtils.deleteDirectory(tempDir) } @@ -191,7 +191,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), - ("a", "b") :: Nil) + Row("a", "b")) FileUtils.deleteDirectory(tempDir) sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -210,7 +210,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // New table should reflect new schema. 
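The reworked QueryTest shown a little earlier (the copy in sql/hive mirrors the one in sql/core) is what all of these conversions feed into: expectations arrive as Seq[Row] or a single Row, BigDecimals are normalised, and results are sorted unless the query itself orders them. A condensed, hypothetical usage sketch; the suite name and queries are illustrative and assume the in-tree TestSQLContext helpers:

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.TestSQLContext._

// Hypothetical suite placed next to the ones above, for illustration only.
class RowAnswerSketchSuite extends QueryTest {
  test("single-row and multi-row expectations") {
    checkAnswer(sql("SELECT 1, 'a'"), Row(1, "a"))          // single-Row overload
    checkAnswer(sql("SELECT 0.3"), Seq(Row(0.3)))           // Seq[Row] overload
  }
  // sqlTest("SELECT count(*) FROM tableWithSchema", Seq(Row(10))) would register a test
  // named after the statement, given an implicit SQLContext (see TableScanSuite above).
}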
checkAnswer( sql("SELECT * FROM jsonTable"), - ("a", "b", "c") :: Nil) + Row("a", "b", "c")) FileUtils.deleteDirectory(tempDir) } @@ -253,6 +253,6 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |) """.stripMargin) - sql("DROP TABLE jsonTable").collect.foreach(println) + sql("DROP TABLE jsonTable").collect().foreach(println) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 0b4e76c9d3d2f..6f07fd5a879c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag -import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -141,7 +141,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { before: () => Unit, after: () => Unit, query: String, - expectedAnswer: Seq[Any], + expectedAnswer: Seq[Row], ct: ClassTag[_]) = { before() @@ -183,7 +183,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { /** Tests for MetastoreRelation */ val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key""" - val metastoreAnswer = Seq.fill(4)((238, "val_238", 238, "val_238")) + val metastoreAnswer = Seq.fill(4)(Row(238, "val_238", 238, "val_238")) mkTest( () => (), () => (), @@ -197,7 +197,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val leftSemiJoinQuery = """SELECT * FROM src a |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin - val answer = (86, "val_86") :: Nil + val answer = Row(86, "val_86") var rdd = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index c14f0d24e0dc3..df72be7746ac6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -226,7 +226,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Jdk version leads to different query output for double, so not use createQueryTest here test("division") { val res = sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head - Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res).foreach( x => + Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res.toSeq).foreach( x => assert(x._1 == x._2.asInstanceOf[Double])) } @@ -235,7 +235,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") - assert(sql("SELECT 1").collect() === Array(Seq(1))) + assert(sql("SELECT 1").collect() === Array(Row(1))) setConf("spark.sql.dialect", "hiveql") } @@ -467,7 +467,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestData(2, "str2") :: Nil) testData.registerTempTable("REGisteredTABle") - assertResult(Array(Array(2, "str2"))) { + assertResult(Array(Row(2, "str2"))) { sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + "WHERE TableAliaS.a > 1").collect() } @@ -553,12 +553,12 @@ class HiveQuerySuite extends 
HiveComparisonTest with BeforeAndAfter { // Describe a table assertResult( Array( - Array("key", "int", null), - Array("value", "string", null), - Array("dt", "string", null), - Array("# Partition Information", "", ""), - Array("# col_name", "data_type", "comment"), - Array("dt", "string", null)) + Row("key", "int", null), + Row("value", "string", null), + Row("dt", "string", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("dt", "string", null)) ) { sql("DESCRIBE test_describe_commands1") .select('col_name, 'data_type, 'comment) @@ -568,12 +568,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Describe a table with a fully qualified table name assertResult( Array( - Array("key", "int", null), - Array("value", "string", null), - Array("dt", "string", null), - Array("# Partition Information", "", ""), - Array("# col_name", "data_type", "comment"), - Array("dt", "string", null)) + Row("key", "int", null), + Row("value", "string", null), + Row("dt", "string", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("dt", "string", null)) ) { sql("DESCRIBE default.test_describe_commands1") .select('col_name, 'data_type, 'comment) @@ -623,8 +623,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertResult( Array( - Array("a", "IntegerType", null), - Array("b", "StringType", null)) + Row("a", "IntegerType", null), + Row("b", "StringType", null)) ) { sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 5dafcd6c0a76a..f2374a215291b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -64,7 +64,7 @@ class HiveUdfSuite extends QueryTest { test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { checkAnswer( sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), - 8 + Row(8) ) } @@ -115,7 +115,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") checkAnswer( sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), - Seq(Seq("1"), Seq("2"))) + Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") TestHive.reset() @@ -131,7 +131,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), - Seq(Seq(0), Seq(2), Seq(13))) + Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") TestHive.reset() @@ -146,7 +146,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), - Seq(Seq("a,b,c"), Seq("d,e"))) + Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") TestHive.reset() @@ -160,7 +160,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") checkAnswer( sql("SELECT 
testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), - Seq(Seq("hello world"), Seq("hello goodbye"))) + Seq(Row("hello world"), Row("hello goodbye"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") TestHive.reset() @@ -177,7 +177,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), - Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13"))) + Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") TestHive.reset() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index d41eb9e870bf0..f6bf2dbb5d6e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -41,7 +41,7 @@ class SQLQuerySuite extends QueryTest { } test("CTAS with serde") { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( """CREATE TABLE ctas2 | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" @@ -51,23 +51,23 @@ class SQLQuerySuite extends QueryTest { | AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect + | ORDER BY key, value""".stripMargin).collect() sql( """CREATE TABLE ctas3 | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' | STORED AS textfile AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect + | ORDER BY key, value""".stripMargin).collect() // the table schema may like (key: integer, value: string) sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect + | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect() // do nothing cause the table ctas4 already existed. 
sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() checkAnswer( sql("SELECT k, value FROM ctas1 ORDER BY k, value"), @@ -89,7 +89,7 @@ class SQLQuerySuite extends QueryTest { intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] { sql( """CREATE TABLE ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), @@ -126,7 +126,7 @@ class SQLQuerySuite extends QueryTest { sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), - 1) + Row(1)) checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), Seq.empty[Row]) checkAnswer( @@ -233,7 +233,7 @@ class SQLQuerySuite extends QueryTest { | (s struct, | innerArray:array, | innerMap: map>) - """.stripMargin).collect + """.stripMargin).collect() sql( """ @@ -243,7 +243,7 @@ class SQLQuerySuite extends QueryTest { checkAnswer( sql("SELECT * FROM nullValuesInInnerComplexTypes"), - Seq(Seq(Seq(null, null, null))) + Row(Row(null, null, null)) ) sql("DROP TABLE nullValuesInInnerComplexTypes") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 4bc14bad0ad5f..581f666399492 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -39,7 +39,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("SELECT on Parquet table") { val data = (1 to 4).map(i => (i, s"val_$i")) withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data) + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index 8bbb7f2fdbf48..79fd99d9f89ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -177,81 +177,81 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll test(s"ordering of the partitioning columns $table") { checkAnswer( sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)((1, "part-1")) + Seq.fill(10)(Row(1, "part-1")) ) checkAnswer( sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(("part-1", 1)) + Seq.fill(10)(Row("part-1", 1)) ) } test(s"project the partitioning column $table") { checkAnswer( sql(s"SELECT p, count(*) FROM $table group by p"), - (1, 10) :: - (2, 10) :: - (3, 10) :: - (4, 10) :: - (5, 10) :: - (6, 10) :: - (7, 10) :: - (8, 10) :: - (9, 10) :: - (10, 10) :: Nil + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil ) } test(s"project partitioning and non-partitioning columns $table") { checkAnswer( sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - ("part-1", 1, 10) :: - ("part-2", 2, 10) :: - ("part-3", 3, 10) :: - ("part-4", 4, 10) :: - ("part-5", 5, 10) :: - ("part-6", 6, 10) :: 
- ("part-7", 7, 10) :: - ("part-8", 8, 10) :: - ("part-9", 9, 10) :: - ("part-10", 10, 10) :: Nil + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil ) } test(s"simple count $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table"), - 100) + Row(100)) } test(s"pruned count $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), - 10) + Row(10)) } test(s"non-existant partition $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - 0) + Row(0)) } test(s"multi-partition pruned count $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - 30) + Row(30)) } test(s"non-partition predicates $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - 30) + Row(30)) } test(s"sum $table") { checkAnswer( sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - 1 + 2 + 3) + Row(1 + 2 + 3)) } test(s"hive udfs $table") { @@ -266,6 +266,6 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll test("non-part select(*)") { checkAnswer( sql("SELECT COUNT(*) FROM normal_parquet"), - 10) + Row(10)) } } From 2f82c841fa1bd866dab2eeb8ab48bc3bb801ab52 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 20 Jan 2015 15:20:20 -0800 Subject: [PATCH 06/27] [SPARK-5186] [MLLIB] Vector.equals and Vector.hashCode are very inefficient JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5186 Currently SparseVector is using the inherited equals from Vector, which will create a full-size array for even the sparse vector. The pull request contains a specialized equals optimization that improves on both time and space. 1. The implementation will be consistent with the original. Especially it will keep equality comparison between SparseVector and DenseVector. 
Author: Yuhao Yang Author: Yuhao Yang Closes #3997 from hhbyyh/master and squashes the following commits: 0d9d130 [Yuhao Yang] function name change and ut update 93f0d46 [Yuhao Yang] unify sparse vs dense vectors 985e160 [Yuhao Yang] improve locality for equals bdf8789 [Yuhao Yang] improve equals and rewrite hashCode for Vector a6952c3 [Yuhao Yang] fix scala style for comments 50abef3 [Yuhao Yang] fix ut for sparse vector with explicit 0 f41b135 [Yuhao Yang] iterative equals for sparse vector 5741144 [Yuhao Yang] Specialized equals for SparseVector --- .../apache/spark/mllib/linalg/Vectors.scala | 55 ++++++++++++++++++- .../spark/mllib/linalg/VectorsSuite.scala | 18 ++++++ 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index adbd8266ed6fa..7ee0224ad4662 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -50,13 +50,35 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v: Vector => - util.Arrays.equals(this.toArray, v.toArray) + case v2: Vector => { + if (this.size != v2.size) return false + (this, v2) match { + case (s1: SparseVector, s2: SparseVector) => + Vectors.equals(s1.indices, s1.values, s2.indices, s2.values) + case (s1: SparseVector, d1: DenseVector) => + Vectors.equals(s1.indices, s1.values, 0 until d1.size, d1.values) + case (d1: DenseVector, s1: SparseVector) => + Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) + case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) + } + } case _ => false } } - override def hashCode(): Int = util.Arrays.hashCode(this.toArray) + override def hashCode(): Int = { + var result: Int = size + 31 + this.foreachActive { case (index, value) => + // ignore explict 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + // refer to {@link java.util.Arrays.equals} for hash algorithm + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } + return result + } /** * Converts the instance to a breeze vector. 
@@ -392,6 +414,33 @@ object Vectors { } squaredDistance } + + /** + * Check equality between sparse/dense vectors + */ + private[mllib] def equals( + v1Indices: IndexedSeq[Int], + v1Values: Array[Double], + v2Indices: IndexedSeq[Int], + v2Values: Array[Double]): Boolean = { + val v1Size = v1Values.size + val v2Size = v2Values.size + var k1 = 0 + var k2 = 0 + var allEqual = true + while (allEqual) { + while (k1 < v1Size && v1Values(k1) == 0) k1 += 1 + while (k2 < v2Size && v2Values(k2) == 0) k2 += 1 + + if (k1 >= v1Size || k2 >= v2Size) { + return k1 >= v1Size && k2 >= v2Size // check end alignment + } + allEqual = v1Indices(k1) == v2Indices(k2) && v1Values(k1) == v2Values(k2) + k1 += 1 + k2 += 1 + } + allEqual + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 85ac8ccebfc59..5def899cea117 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -89,6 +89,24 @@ class VectorsSuite extends FunSuite { } } + test("vectors equals with explicit 0") { + val dv1 = Vectors.dense(Array(0, 0.9, 0, 0.8, 0)) + val sv1 = Vectors.sparse(5, Array(1, 3), Array(0.9, 0.8)) + val sv2 = Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(0, 0.9, 0, 0.8, 0)) + + val vectors = Seq(dv1, sv1, sv2) + for (v <- vectors; u <- vectors) { + assert(v === u) + assert(v.## === u.##) + } + + val another = Vectors.sparse(5, Array(0, 1, 3), Array(0, 0.9, 0.2)) + for (v <- vectors) { + assert(v != another) + assert(v.## != another.##) + } + } + test("indexing dense vectors") { val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0) assert(vec(0) === 1.0) From 9a151ce58b3e756f205c9f3ebbbf3ab0ba5b33fd Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 20 Jan 2015 16:40:46 -0800 Subject: [PATCH 07/27] [SPARK-5294][WebUI] Hide tables in AllStagePages for "Active Stages, Completed Stages and Failed Stages" when they are empty Related to SPARK-5228 and #4028, `AllStagesPage` also should hide the table for `ActiveStages`, `CompleteStages` and `FailedStages` when they are empty. Author: Kousuke Saruta Closes #4083 from sarutak/SPARK-5294 and squashes the following commits: a7625c1 [Kousuke Saruta] Fixed conflicts --- .../apache/spark/ui/jobs/AllStagesPage.scala | 106 ++++++++++++------ 1 file changed, 69 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 1da7a988203db..479f967fb1541 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -59,54 +59,86 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) val poolTable = new PoolTable(pools, parent) + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = pendingStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty + val summary: NodeSeq =
    - {if (sc.isDefined) { - // Total duration is not meaningful unless the UI is live -
  • - Total Duration: - {UIUtils.formatDuration(now - sc.get.startTime)} -
  • - }} + { + if (sc.isDefined) { + // Total duration is not meaningful unless the UI is live +
  • + Total Duration: + {UIUtils.formatDuration(now - sc.get.startTime)} +
  • + } + }
  • Scheduling Mode: {listener.schedulingMode.map(_.toString).getOrElse("Unknown")}
  • -
  • - Active Stages: - {activeStages.size} -
  • -
  • - Pending Stages: - {pendingStages.size} -
  • -
  • - Completed Stages: - {numCompletedStages} -
  • -
  • - Failed Stages: - {numFailedStages} -
  • + { + if (shouldShowActiveStages) { +
  • + Active Stages: + {activeStages.size} +
  • + } + } + { + if (shouldShowPendingStages) { +
  • + Pending Stages: + {pendingStages.size} +
  • + } + } + { + if (shouldShowCompletedStages) { +
  • + Completed Stages: + {numCompletedStages} +
  • + } + } + { + if (shouldShowFailedStages) { +
  • + Failed Stages: + {numFailedStages} +
  • + } + }
- val content = summary ++ - {if (sc.isDefined && isFairScheduler) { -

{pools.size} Fair Scheduler Pools

++ poolTable.toNodeSeq - } else { - Seq[Node]() - }} ++ -

Active Stages ({activeStages.size})

++ - activeStagesTable.toNodeSeq ++ -

Pending Stages ({pendingStages.size})

++ - pendingStagesTable.toNodeSeq ++ -

Completed Stages ({numCompletedStages})

++ - completedStagesTable.toNodeSeq ++ -

Failed Stages ({numFailedStages})

++ + var content = summary ++ + { + if (sc.isDefined && isFairScheduler) { +

{pools.size} Fair Scheduler Pools

++ poolTable.toNodeSeq + } else { + Seq[Node]() + } + } + if (shouldShowActiveStages) { + content ++=

Active Stages ({activeStages.size})

++ + activeStagesTable.toNodeSeq + } + if (shouldShowPendingStages) { + content ++=

Pending Stages ({pendingStages.size})

++ + pendingStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

Completed Stages ({numCompletedStages})

++ + completedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

Failed Stages ({numFailedStages})

++ failedStagesTable.toNodeSeq - + } UIUtils.headerSparkPage("Spark Stages (for all jobs)", content, parent) } } From bad6c5721167153d7ed834b49f87bf2980c6ed67 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 20 Jan 2015 22:44:58 -0800 Subject: [PATCH 08/27] [SPARK-5275] [Streaming] include python source code Include the python source code into assembly jar. cc mengxr pwendell Author: Davies Liu Closes #4128 from davies/build_streaming2 and squashes the following commits: 546af4c [Davies Liu] fix indent 48859b2 [Davies Liu] include python source code --- streaming/pom.xml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/streaming/pom.xml b/streaming/pom.xml index d3c6d0347a622..22b0d714b57f6 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -96,5 +96,13 @@ + + + ../python + + pyspark/streaming/*.py + + + From ec5b0f2cef4b30047c7f88bdc00d10b6aa308124 Mon Sep 17 00:00:00 2001 From: Kannan Rajah Date: Tue, 20 Jan 2015 23:34:04 -0800 Subject: [PATCH 09/27] [HOTFIX] Update pom.xml to pull MapR's Hadoop version 2.4.1. Author: Kannan Rajah Closes #4108 from rkannan82/master and squashes the following commits: eca095b [Kannan Rajah] Update pom.xml to pull MapR's Hadoop version 2.4.1. --- pom.xml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index f4466e56c2a53..b993391b15042 100644 --- a/pom.xml +++ b/pom.xml @@ -1507,7 +1507,7 @@ mapr3 1.0.3-mapr-3.0.3 - 2.3.0-mapr-4.0.0-FCS + 2.4.1-mapr-1408 0.94.17-mapr-1405 3.4.5-mapr-1406 @@ -1516,8 +1516,8 @@ mapr4 - 2.3.0-mapr-4.0.0-FCS - 2.3.0-mapr-4.0.0-FCS + 2.4.1-mapr-1408 + 2.4.1-mapr-1408 0.94.17-mapr-1405-4.0.0-FCS 3.4.5-mapr-1406 From 424d8c6ffff42e4231cc1088b7e69e3c0f5e6b56 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 20 Jan 2015 23:37:47 -0800 Subject: [PATCH 10/27] [SPARK-5297][Streaming] Fix Java file stream type erasure problem Current Java file stream doesn't support custom key/value type because of loss of type information, details can be seen in [SPARK-5297](https://issues.apache.org/jira/browse/SPARK-5297). Fix this problem by getting correct `ClassTag` from `Class[_]`. 
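A small self-contained sketch, outside Spark, of the erasure problem this patch addresses: summoning a ClassTag implicitly inside a Java-facing method yields ClassTag[AnyRef], while building the tag from an explicit Class argument preserves the runtime type. The object and method names below are illustrative only.

import scala.reflect.ClassTag

object ClassTagSketch {
  // Before: no real tag is available, so the cast yields a tag whose runtime
  // class is java.lang.Object regardless of K.
  def erasedTag[K]: ClassTag[K] =
    implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]

  // After: constructing the tag from a Class[K] passed by the Java caller
  // keeps the concrete key/value type.
  def tagFrom[K](kClass: Class[K]): ClassTag[K] = ClassTag(kClass)

  def main(args: Array[String]): Unit = {
    println(erasedTag[String].runtimeClass)        // class java.lang.Object
    println(tagFrom(classOf[String]).runtimeClass) // class java.lang.String
  }
}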
Author: jerryshao Closes #4101 from jerryshao/SPARK-5297 and squashes the following commits: e022ca3 [jerryshao] Add Mima exclusion ecd61b8 [jerryshao] Fix Java fileInputStream type erasure problem --- project/MimaExcludes.scala | 4 ++ .../api/java/JavaStreamingContext.scala | 53 +++++++++++--- .../apache/spark/streaming/JavaAPISuite.java | 70 +++++++++++++++++-- 3 files changed, 112 insertions(+), 15 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 95fef23ee4f39..127973b658190 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -86,6 +86,10 @@ object MimaExcludes { // SPARK-5270 ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.isEmpty") + ) ++ Seq( + // SPARK-5297 Java FileStream do not work with custom key/values + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") ) case v if v.startsWith("1.2") => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index d8695b8e05962..9a2254bcdc1f7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -17,14 +17,15 @@ package org.apache.spark.streaming.api.java +import java.lang.{Boolean => JBoolean} +import java.io.{Closeable, InputStream} +import java.util.{List => JList, Map => JMap} import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import java.io.{Closeable, InputStream} -import java.util.{List => JList, Map => JMap} - import akka.actor.{Props, SupervisorStrategy} +import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.{SparkConf, SparkContext} @@ -250,21 +251,53 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Files must be written to the monitored directory by "moving" them from another * location within the same file system. File names starting with . are ignored. * @param directory HDFS directory to monitor for new file + * @param kClass class of key for reading HDFS file + * @param vClass class of value for reading HDFS file + * @param fClass class of input format for reading HDFS file * @tparam K Key type for reading HDFS file * @tparam V Value type for reading HDFS file * @tparam F Input format for reading HDFS file */ def fileStream[K, V, F <: NewInputFormat[K, V]]( - directory: String): JavaPairInputDStream[K, V] = { - implicit val cmk: ClassTag[K] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] - implicit val cmv: ClassTag[V] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] - implicit val cmf: ClassTag[F] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[F]] + directory: String, + kClass: Class[K], + vClass: Class[V], + fClass: Class[F]): JavaPairInputDStream[K, V] = { + implicit val cmk: ClassTag[K] = ClassTag(kClass) + implicit val cmv: ClassTag[V] = ClassTag(vClass) + implicit val cmf: ClassTag[F] = ClassTag(fClass) ssc.fileStream[K, V, F](directory) } + /** + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * Files must be written to the monitored directory by "moving" them from another + * location within the same file system. 
File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + * @param kClass class of key for reading HDFS file + * @param vClass class of value for reading HDFS file + * @param fClass class of input format for reading HDFS file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[K, V, F <: NewInputFormat[K, V]]( + directory: String, + kClass: Class[K], + vClass: Class[V], + fClass: Class[F], + filter: JFunction[Path, JBoolean], + newFilesOnly: Boolean): JavaPairInputDStream[K, V] = { + implicit val cmk: ClassTag[K] = ClassTag(kClass) + implicit val cmv: ClassTag[V] = ClassTag(vClass) + implicit val cmf: ClassTag[F] = ClassTag(fClass) + def fn = (x: Path) => filter.call(x).booleanValue() + ssc.fileStream[K, V, F](directory, fn, newFilesOnly) + } + /** * Create an input stream with any arbitrary user implemented actor receiver. * @param props Props object defining creation of the actor diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 12cc0de7509d6..d92e7fe899a09 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -17,13 +17,20 @@ package org.apache.spark.streaming; +import java.io.*; +import java.lang.Iterable; +import java.nio.charset.Charset; +import java.util.*; + +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; import scala.Tuple2; import org.junit.Assert; +import static org.junit.Assert.*; import org.junit.Test; -import java.io.*; -import java.util.*; -import java.lang.Iterable; import com.google.common.base.Optional; import com.google.common.collect.Lists; @@ -1743,13 +1750,66 @@ public Iterable call(InputStream in) throws IOException { StorageLevel.MEMORY_ONLY()); } + @SuppressWarnings("unchecked") + @Test + public void testTextFileStream() throws IOException { + File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir")); + List> expected = fileTestPrepare(testDir); + + JavaDStream input = ssc.textFileStream(testDir.toString()); + JavaTestUtils.attachTestOutputStream(input); + List> result = JavaTestUtils.runStreams(ssc, 1, 1); + + assertOrderInvariantEquals(expected, result); + } + + @SuppressWarnings("unchecked") @Test - public void testTextFileStream() { - JavaDStream test = ssc.textFileStream("/tmp/foo"); + public void testFileStream() throws IOException { + File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir")); + List> expected = fileTestPrepare(testDir); + + JavaPairInputDStream inputStream = ssc.fileStream( + testDir.toString(), + LongWritable.class, + Text.class, + TextInputFormat.class, + new Function() { + @Override + public Boolean call(Path v1) throws Exception { + return Boolean.TRUE; + } + }, + true); + + JavaDStream test = inputStream.map( + new Function, String>() { + @Override + public String call(Tuple2 v1) throws Exception { + return v1._2().toString(); + } + }); + + JavaTestUtils.attachTestOutputStream(test); + List> result = JavaTestUtils.runStreams(ssc, 1, 1); + + 
assertOrderInvariantEquals(expected, result); } @Test public void testRawSocketStream() { JavaReceiverInputDStream test = ssc.rawSocketStream("localhost", 12345); } + + private List> fileTestPrepare(File testDir) throws IOException { + File existingFile = new File(testDir, "0"); + Files.write("0\n", existingFile, Charset.forName("UTF-8")); + assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() == 1000); + + List> expected = Arrays.asList( + Arrays.asList("0") + ); + + return expected; + } } From 8c06a5faacfc71050461273133b9cf9a9dd8986f Mon Sep 17 00:00:00 2001 From: WangTao Date: Wed, 21 Jan 2015 09:42:30 -0600 Subject: [PATCH 11/27] [SPARK-5336][YARN]spark.executor.cores must not be less than spark.task.cpus https://issues.apache.org/jira/browse/SPARK-5336 Author: WangTao Author: WangTaoTheTonic Closes #4123 from WangTaoTheTonic/SPARK-5336 and squashes the following commits: 6c9676a [WangTao] Update ClientArguments.scala 9632d3a [WangTaoTheTonic] minor comment fix d03d6fa [WangTaoTheTonic] import ordering should be alphabetical' 3112af9 [WangTao] spark.executor.cores must not be less than spark.task.cpus --- .../org/apache/spark/ExecutorAllocationManager.scala | 2 +- .../org/apache/spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../org/apache/spark/deploy/yarn/ClientArguments.scala | 10 +++++++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index a0ee2a7cbb2a2..b28da192c1c0d 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -158,7 +158,7 @@ private[spark] class ExecutorAllocationManager( "shuffle service. 
You may enable this through spark.shuffle.service.enabled.") } if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores") + throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a1dfb01062591..33a7aae5d3fcd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -168,7 +168,7 @@ private[spark] class TaskSchedulerImpl( if (!hasLaunchedTask) { logWarning("Initial job has not accepted any resources; " + "check your cluster UI to ensure that workers are registered " + - "and have sufficient memory") + "and have sufficient resources") } else { this.cancel() } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 79bead77ba6e4..f96b245512271 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -19,9 +19,9 @@ package org.apache.spark.deploy.yarn import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.util.{Utils, IntParam, MemoryParam} +import org.apache.spark.util.{IntParam, MemoryParam, Utils} // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) { @@ -95,6 +95,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException( "You must specify at least 1 executor!\n" + getUsageMessage()) } + if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) { + throw new SparkException("Executor cores must not be less than " + + "spark.task.cpus.") + } if (isClusterMode) { for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) { if (sparkConf.contains(key)) { @@ -222,7 +226,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) - | --executor-cores NUM Number of cores for the executors (Default: 1). + | --executor-cores NUM Number of cores per executor (Default: 1). | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb) | --driver-cores NUM Number of cores used by the driver (Default: 1). | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) From 2eeada373e59d63b774ba92eb5d75fcd3a1cf8f4 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 21 Jan 2015 10:31:54 -0600 Subject: [PATCH 12/27] SPARK-1714. Take advantage of AMRMClient APIs to simplify logic in YarnA... ...llocator The goal of this PR is to simplify YarnAllocator as much as possible and get it up to the level of code quality we see in the rest of Spark. In service of this, it does a few things: * Uses AMRMClient APIs for matching containers to requests. 
* Adds calls to AMRMClient.removeContainerRequest so that, when we use a container, we don't end up requesting it again. * Removes YarnAllocator's host->rack cache. YARN's RackResolver already does this caching, so this is redundant. * Adds tests for basic YarnAllocator functionality. * Breaks up the allocateResources method, which was previously nearly 300 lines. * A little bit of stylistic cleanup. * Fixes a bug that causes three times the requests to be filed when preferred host locations are given. The patch is lossy. In particular, it loses the logic for trying to avoid containers bunching up on nodes. As I understand it, the logic that's gone is: * If, in a single response from the RM, we receive a set of containers on a node, and prefer some number of containers on that node greater than 0 but less than the number we received, give back the delta between what we preferred and what we received. This seems like a weird way to avoid bunching E.g. it does nothing to avoid bunching when we don't request containers on particular nodes. Author: Sandy Ryza Closes #3765 from sryza/sandy-spark-1714 and squashes the following commits: 32a5942 [Sandy Ryza] Muffle RackResolver logs 74f56dd [Sandy Ryza] Fix a couple comments and simplify requestTotalExecutors 60ea4bd [Sandy Ryza] Fix scalastyle ca35b53 [Sandy Ryza] Simplify further e9cf8a6 [Sandy Ryza] Fix YarnClusterSuite 257acf3 [Sandy Ryza] Remove locality stuff and more cleanup 59a3c5e [Sandy Ryza] Take out rack stuff 5f72fd5 [Sandy Ryza] Further documentation and cleanup 89edd68 [Sandy Ryza] SPARK-1714. Take advantage of AMRMClient APIs to simplify logic in YarnAllocator --- .../apache/spark/log4j-defaults.properties | 1 + .../spark/deploy/yarn/YarnAllocator.scala | 733 ++++++------------ .../spark/deploy/yarn/YarnRMClient.scala | 3 +- .../deploy/yarn/YarnSparkHadoopUtil.scala | 41 +- .../cluster/YarnClientClusterScheduler.scala | 5 +- .../cluster/YarnClusterScheduler.scala | 6 +- .../deploy/yarn/YarnAllocatorSuite.scala | 150 +++- 7 files changed, 389 insertions(+), 550 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 89eec7d4b7f61..c99a61f63ea2b 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -10,3 +10,4 @@ log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO +log4j.logger.org.apache.hadoop.yarn.util.RackResolver=WARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index de65ef23ad1ce..4c35b60c57df3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.yarn +import java.util.Collections import java.util.concurrent._ -import java.util.concurrent.atomic.AtomicInteger import java.util.regex.Pattern import scala.collection.JavaConversions._ @@ -28,33 +28,26 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse import 
org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.hadoop.yarn.util.Records +import org.apache.hadoop.yarn.util.RackResolver -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -object AllocationType extends Enumeration { - type AllocationType = Value - val HOST, RACK, ANY = Value -} - -// TODO: -// Too many params. -// Needs to be mt-safe -// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should -// make it more proactive and decoupled. - -// Note that right now, we assume all node asks as uniform in terms of capabilities and priority -// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for -// more info on how we are requesting for containers. - /** - * Acquires resources for executors from a ResourceManager and launches executors in new containers. + * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding + * what to do with containers when YARN fulfills these requests. + * + * This class makes use of YARN's AMRMClient APIs. We interact with the AMRMClient in three ways: + * * Making our resource needs known, which updates local bookkeeping about containers requested. + * * Calling "allocate", which syncs our local container requests with the RM, and returns any + * containers that YARN has granted to us. This also functions as a heartbeat. + * * Processing the containers granted to us to possibly launch executors inside of them. + * + * The public methods of this class are thread-safe. All methods that mutate state are + * synchronized. */ private[yarn] class YarnAllocator( conf: Configuration, @@ -62,50 +55,42 @@ private[yarn] class YarnAllocator( amClient: AMRMClient[ContainerRequest], appAttemptId: ApplicationAttemptId, args: ApplicationMasterArguments, - preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) extends Logging { import YarnAllocator._ - // These three are locked on allocatedHostToContainersMap. Complementary data structures - // allocatedHostToContainersMap : containers which are running : host, Set - // allocatedContainerToHostMap: container to host mapping. - private val allocatedHostToContainersMap = - new HashMap[String, collection.mutable.Set[ContainerId]]() + // These two complementary data structures are locked on allocatedHostToContainersMap. + // Visible for testing. + val allocatedHostToContainersMap = + new HashMap[String, collection.mutable.Set[ContainerId]] + val allocatedContainerToHostMap = new HashMap[ContainerId, String] - private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() + // Containers that we no longer care about. We've either already told the RM to release them or + // will on the next heartbeat. Containers get removed from this map after the RM tells us they've + // completed. 
+ private val releasedContainers = Collections.newSetFromMap[ContainerId]( + new ConcurrentHashMap[ContainerId, java.lang.Boolean]) - // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an - // allocated node) - // As with the two data structures above, tightly coupled with them, and to be locked on - // allocatedHostToContainersMap - private val allocatedRackCount = new HashMap[String, Int]() + @volatile private var numExecutorsRunning = 0 + // Used to generate a unique ID per executor + private var executorIdCounter = 0 + @volatile private var numExecutorsFailed = 0 - // Containers to be released in next request to RM - private val releasedContainers = new ConcurrentHashMap[ContainerId, Boolean] - - // Number of container requests that have been sent to, but not yet allocated by the - // ApplicationMaster. - private val numPendingAllocate = new AtomicInteger() - private val numExecutorsRunning = new AtomicInteger() - // Used to generate a unique id per executor - private val executorIdCounter = new AtomicInteger() - private val numExecutorsFailed = new AtomicInteger() - - private var maxExecutors = args.numExecutors + @volatile private var maxExecutors = args.numExecutors // Keep track of which container is running which executor to remove the executors later private val executorIdToContainer = new HashMap[String, Container] + // Executor memory in MB. protected val executorMemory = args.executorMemory - protected val executorCores = args.executorCores - protected val (preferredHostToCount, preferredRackToCount) = - generateNodeToWeight(conf, preferredNodes) - - // Additional memory overhead - in mb. + // Additional memory overhead. protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) + // Number of cores per executor. + protected val executorCores = args.executorCores + // Resource capability requested for each executors + private val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) private val launcherPool = new ThreadPoolExecutor( // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue @@ -115,26 +100,34 @@ private[yarn] class YarnAllocator( new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) launcherPool.allowCoreThreadTimeOut(true) - def getNumExecutorsRunning: Int = numExecutorsRunning.intValue + private val driverUrl = "akka.tcp://sparkDriver@%s:%s/user/%s".format( + sparkConf.get("spark.driver.host"), + sparkConf.get("spark.driver.port"), + CoarseGrainedSchedulerBackend.ACTOR_NAME) + + // For testing + private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) - def getNumExecutorsFailed: Int = numExecutorsFailed.intValue + def getNumExecutorsRunning: Int = numExecutorsRunning + + def getNumExecutorsFailed: Int = numExecutorsFailed + + /** + * Number of container requests that have not yet been fulfilled. + */ + def getNumPendingAllocate: Int = getNumPendingAtLocation(ANY_HOST) + + /** + * Number of container requests at the given location that have not yet been fulfilled. + */ + private def getNumPendingAtLocation(location: String): Int = + amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum /** * Request as many executors from the ResourceManager as needed to reach the desired total. - * This takes into account executors already running or pending. 
*/ def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { - val currentTotal = numPendingAllocate.get + numExecutorsRunning.get - if (requestedTotal > currentTotal) { - maxExecutors += (requestedTotal - currentTotal) - // We need to call `allocateResources` here to avoid the following race condition: - // If we request executors twice before `allocateResources` is called, then we will end up - // double counting the number requested because `numPendingAllocate` is not updated yet. - allocateResources() - } else { - logInfo(s"Not allocating more executors because there are already $currentTotal " + - s"(application requested $requestedTotal total)") - } + maxExecutors = requestedTotal } /** @@ -144,7 +137,7 @@ private[yarn] class YarnAllocator( if (executorIdToContainer.contains(executorId)) { val container = executorIdToContainer.remove(executorId).get internalReleaseContainer(container) - numExecutorsRunning.decrementAndGet() + numExecutorsRunning -= 1 maxExecutors -= 1 assert(maxExecutors >= 0, "Allocator killed more executors than are allocated!") } else { @@ -153,498 +146,236 @@ private[yarn] class YarnAllocator( } /** - * Allocate missing containers based on the number of executors currently pending and running. + * Request resources such that, if YARN gives us all we ask for, we'll have a number of containers + * equal to maxExecutors. * - * This method prioritizes the allocated container responses from the RM based on node and - * rack locality. Additionally, it releases any extra containers allocated for this application - * but are not needed. This must be synchronized because variables read in this block are - * mutated by other methods. + * Deal with any containers YARN has granted to us by possibly launching executors in them. + * + * This must be synchronized because variables read in this method are mutated by other methods. */ def allocateResources(): Unit = synchronized { - val missing = maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get() + val numPendingAllocate = getNumPendingAllocate + val missing = maxExecutors - numPendingAllocate - numExecutorsRunning if (missing > 0) { - val totalExecutorMemory = executorMemory + memoryOverhead - numPendingAllocate.addAndGet(missing) - logInfo(s"Will allocate $missing executor containers, each with $totalExecutorMemory MB " + - s"memory including $memoryOverhead MB overhead") - } else { - logDebug("Empty allocation request ...") + logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") } - val allocateResponse = allocateContainers(missing) + addResourceRequests(missing) + val progressIndicator = 0.1f + // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container + // requests. + val allocateResponse = amClient.allocate(progressIndicator) + val allocatedContainers = allocateResponse.getAllocatedContainers() if (allocatedContainers.size > 0) { - var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size) - - if (numPendingAllocateNow < 0) { - numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow) - } - - logDebug(""" - Allocated containers: %d - Current executor count: %d - Containers released: %s - Cluster resources: %s - """.format( + logDebug("Allocated containers: %d. Current executor count: %d. Cluster resources: %s." 
+ .format( allocatedContainers.size, - numExecutorsRunning.get(), - releasedContainers, + numExecutorsRunning, allocateResponse.getAvailableResources)) - val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (container <- allocatedContainers) { - if (isResourceConstraintSatisfied(container)) { - // Add the accepted `container` to the host's list of already accepted, - // allocated containers - val host = container.getNodeId.getHost - val containersForHost = hostToContainers.getOrElseUpdate(host, - new ArrayBuffer[Container]()) - containersForHost += container - } else { - // Release container, since it doesn't satisfy resource constraints. - internalReleaseContainer(container) - } - } - - // Find the appropriate containers to use. - // TODO: Cleanup this group-by... - val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (candidateHost <- hostToContainers.keySet) { - val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) - val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) - - val remainingContainersOpt = hostToContainers.get(candidateHost) - assert(remainingContainersOpt.isDefined) - var remainingContainers = remainingContainersOpt.get - - if (requiredHostCount >= remainingContainers.size) { - // Since we have <= required containers, add all remaining containers to - // `dataLocalContainers`. - dataLocalContainers.put(candidateHost, remainingContainers) - // There are no more free containers remaining. - remainingContainers = null - } else if (requiredHostCount > 0) { - // Container list has more containers than we need for data locality. - // Split the list into two: one based on the data local container count, - // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining - // containers. - val (dataLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredHostCount) - dataLocalContainers.put(candidateHost, dataLocal) - - // Invariant: remainingContainers == remaining - - // YARN has a nasty habit of allocating a ton of containers on a host - discourage this. - // Add each container in `remaining` to list of containers to release. If we have an - // insufficient number of containers, then the next allocation cycle will reallocate - // (but won't treat it as data local). - // TODO(harvey): Rephrase this comment some more. - for (container <- remaining) internalReleaseContainer(container) - remainingContainers = null - } - - // For rack local containers - if (remainingContainers != null) { - val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) - if (rack != null) { - val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) - val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - - rackLocalContainers.getOrElse(rack, List()).size - - if (requiredRackCount >= remainingContainers.size) { - // Add all remaining containers to to `dataLocalContainers`. - dataLocalContainers.put(rack, remainingContainers) - remainingContainers = null - } else if (requiredRackCount > 0) { - // Container list has more containers that we need for data locality. - // Split the list into two: one based on the data local container count, - // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining - // containers. 
- val (rackLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredRackCount) - val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, - new ArrayBuffer[Container]()) - - existingRackLocal ++= rackLocal - - remainingContainers = remaining - } - } - } - - if (remainingContainers != null) { - // Not all containers have been consumed - add them to the list of off-rack containers. - offRackContainers.put(candidateHost, remainingContainers) - } - } - - // Now that we have split the containers into various groups, go through them in order: - // first host-local, then rack-local, and finally off-rack. - // Note that the list we create below tries to ensure that not all containers end up within - // a host if there is a sufficiently large number of hosts/containers. - val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) - - // Run each of the allocated containers. - for (container <- allocatedContainersToProcess) { - val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() - val executorHostname = container.getNodeId.getHost - val containerId = container.getId - - val executorMemoryOverhead = (executorMemory + memoryOverhead) - assert(container.getResource.getMemory >= executorMemoryOverhead) - - if (numExecutorsRunningNow > maxExecutors) { - logInfo("""Ignoring container %s at host %s, since we already have the required number of - containers for it.""".format(containerId, executorHostname)) - internalReleaseContainer(container) - numExecutorsRunning.decrementAndGet() - } else { - val executorId = executorIdCounter.incrementAndGet().toString - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - - logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) - executorIdToContainer(executorId) = container - - // To be safe, remove the container from `releasedContainers`. - releasedContainers.remove(containerId) - - val rack = YarnSparkHadoopUtil.lookupRack(conf, executorHostname) - allocatedHostToContainersMap.synchronized { - val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, - new HashSet[ContainerId]()) - - containerSet += containerId - allocatedContainerToHostMap.put(containerId, executorHostname) - - if (rack != null) { - allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) - } - } - logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( - driverUrl, executorHostname)) - val executorRunnable = new ExecutorRunnable( - container, - conf, - sparkConf, - driverUrl, - executorId, - executorHostname, - executorMemory, - executorCores, - appAttemptId.getApplicationId.toString, - securityMgr) - launcherPool.execute(executorRunnable) - } - } - logDebug(""" - Finished allocating %s containers (from %s originally). 
- Current number of executors running: %d, - Released containers: %s - """.format( - allocatedContainersToProcess, - allocatedContainers, - numExecutorsRunning.get(), - releasedContainers)) + handleAllocatedContainers(allocatedContainers) } val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - for (completedContainer <- completedContainers) { - val containerId = completedContainer.getContainerId - - if (releasedContainers.containsKey(containerId)) { - // Already marked the container for release, so remove it from - // `releasedContainers`. - releasedContainers.remove(containerId) - } else { - // Decrement the number of executors running. The next iteration of - // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning.decrementAndGet() - logInfo("Completed container %s (state: %s, exit status: %s)".format( - containerId, - completedContainer.getState, - completedContainer.getExitStatus)) - // Hadoop 2.2.X added a ContainerExitStatus we should switch to use - // there are some exit status' we shouldn't necessarily count against us, but for - // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus == -103) { // vmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus != 0) { - logInfo("Container marked as failed: " + containerId + - ". Exit status: " + completedContainer.getExitStatus + - ". Diagnostics: " + completedContainer.getDiagnostics) - numExecutorsFailed.incrementAndGet() - } - } + processCompletedContainers(completedContainers) - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val hostOpt = allocatedContainerToHostMap.get(containerId) - assert(hostOpt.isDefined) - val host = hostOpt.get - - val containerSetOpt = allocatedHostToContainersMap.get(host) - assert(containerSetOpt.isDefined) - val containerSet = containerSetOpt.get - - containerSet.remove(containerId) - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap.remove(containerId) - - // TODO: Move this part outside the synchronized block? - val rack = YarnSparkHadoopUtil.lookupRack(conf, host) - if (rack != null) { - val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 - if (rackCount > 0) { - allocatedRackCount.put(rack, rackCount) - } else { - allocatedRackCount.remove(rack) - } - } - } - } - } - logDebug(""" - Finished processing %d completed containers. - Current number of executors running: %d, - Released containers: %s - """.format( - completedContainers.size, - numExecutorsRunning.get(), - releasedContainers)) + logDebug("Finished processing %d completed containers. Current running executor count: %d." + .format(completedContainers.size, numExecutorsRunning)) } } - private def allocatedContainersOnHost(host: String): Int = { - allocatedHostToContainersMap.synchronized { - allocatedHostToContainersMap.getOrElse(host, Set()).size + /** + * Request numExecutors additional containers from YARN. Visible for testing. 
+ */ + def addResourceRequests(numExecutors: Int): Unit = { + for (i <- 0 until numExecutors) { + val request = new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY) + amClient.addContainerRequest(request) + val nodes = request.getNodes + val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last + logInfo("Container request (host: %s, capability: %s".format(hostStr, resource)) } } - private def allocatedContainersOnRack(rack: String): Int = { - allocatedHostToContainersMap.synchronized { - allocatedRackCount.getOrElse(rack, 0) + /** + * Handle containers granted by the RM by launching executors on them. + * + * Due to the way the YARN allocation protocol works, certain healthy race conditions can result + * in YARN granting containers that we no longer need. In this case, we release them. + * + * Visible for testing. + */ + def handleAllocatedContainers(allocatedContainers: Seq[Container]): Unit = { + val containersToUse = new ArrayBuffer[Container](allocatedContainers.size) + + // Match incoming requests by host + val remainingAfterHostMatches = new ArrayBuffer[Container] + for (allocatedContainer <- allocatedContainers) { + matchContainerToRequest(allocatedContainer, allocatedContainer.getNodeId.getHost, + containersToUse, remainingAfterHostMatches) } - } - - private def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + memoryOverhead) - } - // A simple method to copy the split info map. - private def generateNodeToWeight( - conf: Configuration, - input: collection.Map[String, collection.Set[SplitInfo]]) - : (Map[String, Int], Map[String, Int]) = { - if (input == null) { - return (Map[String, Int](), Map[String, Int]()) + // Match remaining by rack + val remainingAfterRackMatches = new ArrayBuffer[Container] + for (allocatedContainer <- remainingAfterHostMatches) { + val rack = RackResolver.resolve(conf, allocatedContainer.getNodeId.getHost).getNetworkLocation + matchContainerToRequest(allocatedContainer, rack, containersToUse, + remainingAfterRackMatches) } - val hostToCount = new HashMap[String, Int] - val rackToCount = new HashMap[String, Int] - - for ((host, splits) <- input) { - val hostCount = hostToCount.getOrElse(host, 0) - hostToCount.put(host, hostCount + splits.size) + // Assign remaining that are neither node-local nor rack-local + val remainingAfterOffRackMatches = new ArrayBuffer[Container] + for (allocatedContainer <- remainingAfterRackMatches) { + matchContainerToRequest(allocatedContainer, ANY_HOST, containersToUse, + remainingAfterOffRackMatches) + } - val rack = YarnSparkHadoopUtil.lookupRack(conf, host) - if (rack != null) { - val rackCount = rackToCount.getOrElse(host, 0) - rackToCount.put(host, rackCount + splits.size) + if (!remainingAfterOffRackMatches.isEmpty) { + logDebug(s"Releasing ${remainingAfterOffRackMatches.size} unneeded containers that were " + + s"allocated to us") + for (container <- remainingAfterOffRackMatches) { + internalReleaseContainer(container) } } - (hostToCount.toMap, rackToCount.toMap) - } + runAllocatedContainers(containersToUse) - private def internalReleaseContainer(container: Container): Unit = { - releasedContainers.put(container.getId(), true) - amClient.releaseAssignedContainer(container.getId()) + logInfo("Received %d containers from YARN, launching executors on %d of them." + .format(allocatedContainers.size, containersToUse.size)) } /** - * Called to allocate containers in the cluster. 
+ * Looks for requests for the given location that match the given container allocation. If it + * finds one, removes the request so that it won't be submitted again. Places the container into + * containersToUse or remaining. * - * @param count Number of containers to allocate. - * If zero, should still contact RM (as a heartbeat). - * @return Response to the allocation request. + * @param allocatedContainer container that was given to us by YARN + * @location resource name, either a node, rack, or * + * @param containersToUse list of containers that will be used + * @param remaining list of containers that will not be used */ - private def allocateContainers(count: Int): AllocateResponse = { - addResourceRequests(count) - - // We have already set the container request. Poll the ResourceManager for a response. - // This doubles as a heartbeat if there are no pending container requests. - val progressIndicator = 0.1f - amClient.allocate(progressIndicator) + private def matchContainerToRequest( + allocatedContainer: Container, + location: String, + containersToUse: ArrayBuffer[Container], + remaining: ArrayBuffer[Container]): Unit = { + val matchingRequests = amClient.getMatchingRequests(allocatedContainer.getPriority, location, + allocatedContainer.getResource) + + // Match the allocation to a request + if (!matchingRequests.isEmpty) { + val containerRequest = matchingRequests.get(0).iterator.next + amClient.removeContainerRequest(containerRequest) + containersToUse += allocatedContainer + } else { + remaining += allocatedContainer + } } - private def createRackResourceRequests(hostContainers: ArrayBuffer[ContainerRequest]) - : ArrayBuffer[ContainerRequest] = { - // Generate modified racks and new set of hosts under it before issuing requests. - val rackToCounts = new HashMap[String, Int]() - - for (container <- hostContainers) { - val candidateHost = container.getNodes.last - assert(YarnSparkHadoopUtil.ANY_HOST != candidateHost) - - val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) - if (rack != null) { - var count = rackToCounts.getOrElse(rack, 0) - count += 1 - rackToCounts.put(rack, count) + /** + * Launches executors in the allocated containers. + */ + private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = { + for (container <- containersToUse) { + numExecutorsRunning += 1 + assert(numExecutorsRunning <= maxExecutors) + val executorHostname = container.getNodeId.getHost + val containerId = container.getId + executorIdCounter += 1 + val executorId = executorIdCounter.toString + + assert(container.getResource.getMemory >= resource.getMemory) + + logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) + + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, + new HashSet[ContainerId]) + + containerSet += containerId + allocatedContainerToHostMap.put(containerId, executorHostname) + + val executorRunnable = new ExecutorRunnable( + container, + conf, + sparkConf, + driverUrl, + executorId, + executorHostname, + executorMemory, + executorCores, + appAttemptId.getApplicationId.toString, + securityMgr) + if (launchContainers) { + logInfo("Launching ExecutorRunnable. 
driverUrl: %s, executorHostname: %s".format( + driverUrl, executorHostname)) + launcherPool.execute(executorRunnable) } } - - val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size) - for ((rack, count) <- rackToCounts) { - requestedContainers ++= createResourceRequests( - AllocationType.RACK, - rack, - count, - RM_REQUEST_PRIORITY) - } - - requestedContainers } - private def addResourceRequests(numExecutors: Int): Unit = { - val containerRequests: List[ContainerRequest] = - if (numExecutors <= 0) { - logDebug("numExecutors: " + numExecutors) - List() - } else if (preferredHostToCount.isEmpty) { - logDebug("host preferences is empty") - createResourceRequests( - AllocationType.ANY, - resource = null, - numExecutors, - RM_REQUEST_PRIORITY).toList + private def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { + for (completedContainer <- completedContainers) { + val containerId = completedContainer.getContainerId + + if (releasedContainers.contains(containerId)) { + // Already marked the container for release, so remove it from + // `releasedContainers`. + releasedContainers.remove(containerId) } else { - // Request for all hosts in preferred nodes and for numExecutors - - // candidates.size, request by default allocation policy. - val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size) - for ((candidateHost, candidateCount) <- preferredHostToCount) { - val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost) - - if (requiredCount > 0) { - hostContainerRequests ++= createResourceRequests( - AllocationType.HOST, - candidateHost, - requiredCount, - RM_REQUEST_PRIORITY) - } + // Decrement the number of executors running. The next iteration of + // the ApplicationMaster's reporting thread will take care of allocating. + numExecutorsRunning -= 1 + logInfo("Completed container %s (state: %s, exit status: %s)".format( + containerId, + completedContainer.getState, + completedContainer.getExitStatus)) + // Hadoop 2.2.X added a ContainerExitStatus we should switch to use + // there are some exit status' we shouldn't necessarily count against us, but for + // now I think its ok as none of the containers are expected to exit + if (completedContainer.getExitStatus == -103) { // vmem limit exceeded + logWarning(memLimitExceededLogMessage( + completedContainer.getDiagnostics, + VMEM_EXCEEDED_PATTERN)) + } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded + logWarning(memLimitExceededLogMessage( + completedContainer.getDiagnostics, + PMEM_EXCEEDED_PATTERN)) + } else if (completedContainer.getExitStatus != 0) { + logInfo("Container marked as failed: " + containerId + + ". Exit status: " + completedContainer.getExitStatus + + ". 
Diagnostics: " + completedContainer.getDiagnostics) + numExecutorsFailed += 1 } - val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests( - hostContainerRequests).toList - - val anyContainerRequests = createResourceRequests( - AllocationType.ANY, - resource = null, - numExecutors, - RM_REQUEST_PRIORITY) - - val containerRequestBuffer = new ArrayBuffer[ContainerRequest]( - hostContainerRequests.size + rackContainerRequests.size + anyContainerRequests.size) - - containerRequestBuffer ++= hostContainerRequests - containerRequestBuffer ++= rackContainerRequests - containerRequestBuffer ++= anyContainerRequests - containerRequestBuffer.toList } - for (request <- containerRequests) { - amClient.addContainerRequest(request) - } + allocatedHostToContainersMap.synchronized { + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).get + val containerSet = allocatedHostToContainersMap.get(host).get - for (request <- containerRequests) { - val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) { - "Any" - } else { - nodes.last - } - logInfo("Container request (host: %s, priority: %s, capability: %s".format( - hostStr, - request.getPriority().getPriority, - request.getCapability)) - } - } + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) + } - private def createResourceRequests( - requestType: AllocationType.AllocationType, - resource: String, - numExecutors: Int, - priority: Int): ArrayBuffer[ContainerRequest] = { - // If hostname is specified, then we need at least two requests - node local and rack local. - // There must be a third request, which is ANY. That will be specially handled. - requestType match { - case AllocationType.HOST => { - assert(YarnSparkHadoopUtil.ANY_HOST != resource) - val hostname = resource - val nodeLocal = constructContainerRequests( - Array(hostname), - racks = null, - numExecutors, - priority) - - // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler. 
- YarnSparkHadoopUtil.populateRackInfo(conf, hostname) - nodeLocal - } - case AllocationType.RACK => { - val rack = resource - constructContainerRequests(hosts = null, Array(rack), numExecutors, priority) + allocatedContainerToHostMap.remove(containerId) + } } - case AllocationType.ANY => constructContainerRequests( - hosts = null, racks = null, numExecutors, priority) - case _ => throw new IllegalArgumentException( - "Unexpected/unsupported request type: " + requestType) } } - private def constructContainerRequests( - hosts: Array[String], - racks: Array[String], - numExecutors: Int, - priority: Int - ): ArrayBuffer[ContainerRequest] = { - val memoryRequest = executorMemory + memoryOverhead - val resource = Resource.newInstance(memoryRequest, executorCores) - - val prioritySetting = Records.newRecord(classOf[Priority]) - prioritySetting.setPriority(priority) - - val requests = new ArrayBuffer[ContainerRequest]() - for (i <- 0 until numExecutors) { - requests += new ContainerRequest(resource, hosts, racks, prioritySetting) - } - requests + private def internalReleaseContainer(container: Container): Unit = { + releasedContainers.add(container.getId()) + amClient.releaseAssignedContainer(container.getId()) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index b45e599588ad3..b134751366522 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -72,8 +72,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, - preferredNodeLocations, securityMgr) + new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr) } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index d7cf904db1c9e..4bff846123619 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.api.records.ApplicationAccessType +import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} import org.apache.hadoop.yarn.util.RackResolver import org.apache.hadoop.conf.Configuration @@ -99,13 +99,7 @@ object YarnSparkHadoopUtil { // All RM requests are issued with same priority : we do not (yet) have any distinction between // request types (like map/reduce in hadoop for example) - val RM_REQUEST_PRIORITY = 1 - - // Host to rack map - saved from allocation requests. We are expecting this not to change. - // Note that it is possible for this to change : and ResourceManager will indicate that to us via - // update response to allocate. But we are punting on handling that for now. - private val hostToRack = new ConcurrentHashMap[String, String]() - private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() + val RM_REQUEST_PRIORITY = Priority.newInstance(1) /** * Add a path variable to the given environment map. 
@@ -184,37 +178,6 @@ object YarnSparkHadoopUtil { } } - def lookupRack(conf: Configuration, host: String): String = { - if (!hostToRack.contains(host)) { - populateRackInfo(conf, host) - } - hostToRack.get(host) - } - - def populateRackInfo(conf: Configuration, hostname: String) { - Utils.checkHost(hostname) - - if (!hostToRack.containsKey(hostname)) { - // If there are repeated failures to resolve, all to an ignore list. - val rackInfo = RackResolver.resolve(conf, hostname) - if (rackInfo != null && rackInfo.getNetworkLocation != null) { - val rack = rackInfo.getNetworkLocation - hostToRack.put(hostname, rack) - if (! rackToHostSet.containsKey(rack)) { - rackToHostSet.putIfAbsent(rack, - Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) - } - rackToHostSet.get(rack).add(hostname) - - // TODO(harvey): Figure out what this comment means... - // Since RackResolver caches, we are disabling this for now ... - } /* else { - // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... - hostToRack.put(hostname, null) - } */ - } - } - def getApplicationAclsForYarn(securityMgr: SecurityManager) : Map[ApplicationAccessType, String] = { Map[ApplicationAccessType, String] ( diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 254774a6b839e..2fa24cc43325e 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -17,8 +17,9 @@ package org.apache.spark.scheduler.cluster +import org.apache.hadoop.yarn.util.RackResolver + import org.apache.spark._ -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils @@ -30,6 +31,6 @@ private[spark] class YarnClientClusterScheduler(sc: SparkContext) extends TaskSc // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { val host = Utils.parseHostPort(hostPort)._1 - Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host)) + Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation) } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 4157ff95c2794..be55d26f1cf61 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -17,8 +17,10 @@ package org.apache.spark.scheduler.cluster +import org.apache.hadoop.yarn.util.RackResolver + import org.apache.spark._ -import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.ApplicationMaster import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils @@ -39,7 +41,7 @@ private[spark] class YarnClusterScheduler(sc: SparkContext) extends TaskSchedule // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { val host = Utils.parseHostPort(hostPort)._1 - Option(YarnSparkHadoopUtil.lookupRack(sc.hadoopConfiguration, host)) + Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation) } override def postStartHook() { diff --git 
a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 8d184a09d64cc..024b25f9d3365 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -17,18 +17,160 @@ package org.apache.spark.deploy.yarn +import java.util.{Arrays, List => JList} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.CommonConfigurationKeysPublic +import org.apache.hadoop.net.DNSToSwitchMapping +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +import org.apache.spark.SecurityManager +import org.apache.spark.SparkConf +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ -import org.scalatest.FunSuite +import org.apache.spark.scheduler.SplitInfo + +import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} + +class MockResolver extends DNSToSwitchMapping { + + override def resolve(names: JList[String]): JList[String] = { + if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2") + else Arrays.asList("/rack1") + } + + override def reloadCachedMappings() {} + + def reloadCachedMappings(names: JList[String]) {} +} + +class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach { + val conf = new Configuration() + conf.setClass( + CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, + classOf[MockResolver], classOf[DNSToSwitchMapping]) + + val sparkConf = new SparkConf() + sparkConf.set("spark.driver.host", "localhost") + sparkConf.set("spark.driver.port", "4040") + sparkConf.set("spark.yarn.jar", "notarealjar.jar") + sparkConf.set("spark.yarn.launchContainers", "false") + + val appAttemptId = ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0) + + // Resource returned by YARN. YARN can give larger containers than requested, so give 6 cores + // instead of the 5 requested and 3 GB instead of the 2 requested. 
+ val containerResource = Resource.newInstance(3072, 6) + + var rmClient: AMRMClient[ContainerRequest] = _ + + var containerNum = 0 + + override def beforeEach() { + rmClient = AMRMClient.createAMRMClient() + rmClient.init(conf) + rmClient.start() + } + + override def afterEach() { + rmClient.stop() + } + + class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) { + override def equals(other: Any) = false + } + + def createAllocator(maxExecutors: Int = 5): YarnAllocator = { + val args = Array( + "--num-executors", s"$maxExecutors", + "--executor-cores", "5", + "--executor-memory", "2048", + "--jar", "somejar.jar", + "--class", "SomeClass") + new YarnAllocator( + conf, + sparkConf, + rmClient, + appAttemptId, + new ApplicationMasterArguments(args), + new SecurityManager(sparkConf)) + } + + def createContainer(host: String): Container = { + val containerId = ContainerId.newInstance(appAttemptId, containerNum) + containerNum += 1 + val nodeId = NodeId.newInstance(host, 1000) + Container.newInstance(containerId, nodeId, "", containerResource, RM_REQUEST_PRIORITY, null) + } + + test("single container allocated") { + // request a single container and receive it + val handler = createAllocator() + handler.addResourceRequests(1) + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (1) + + val container = createContainer("host1") + handler.handleAllocatedContainers(Array(container)) + + handler.getNumExecutorsRunning should be (1) + handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") + handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) + rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size should be (0) + } + + test("some containers allocated") { + // request a few containers and receive some of them + val handler = createAllocator() + handler.addResourceRequests(4) + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host1") + val container3 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2, container3)) + + handler.getNumExecutorsRunning should be (3) + handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1") + handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host1") + handler.allocatedContainerToHostMap.get(container3.getId).get should be ("host2") + handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId) + handler.allocatedHostToContainersMap.get("host1").get should contain (container2.getId) + handler.allocatedHostToContainersMap.get("host2").get should contain (container3.getId) + } + + test("receive more containers than requested") { + val handler = createAllocator(2) + handler.addResourceRequests(2) + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + val container3 = createContainer("host4") + handler.handleAllocatedContainers(Array(container1, container2, container3)) + + handler.getNumExecutorsRunning should be (2) + handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1") + handler.allocatedContainerToHostMap.get(container2.getId).get should be ("host2") + handler.allocatedContainerToHostMap.contains(container3.getId) should be (false) + 
handler.allocatedHostToContainersMap.get("host1").get should contain (container1.getId) + handler.allocatedHostToContainersMap.get("host2").get should contain (container2.getId) + handler.allocatedHostToContainersMap.contains("host4") should be (false) + } -class YarnAllocatorSuite extends FunSuite { test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + - "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " + - "5.8 GB of 4.2 GB virtual memory used. Killing container." + "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " + + "5.8 GB of 4.2 GB virtual memory used. Killing container." val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN) val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN) assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) } + } From aa1e22b17b4ce885febe6970a2451c7d17d0acfb Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Wed, 21 Jan 2015 09:48:38 -0800 Subject: [PATCH 13/27] [MLlib] [SPARK-5301] Missing conversions and operations on IndexedRowMatrix and CoordinateMatrix * Transpose is missing from CoordinateMatrix (this is cheap to compute, so it should be there) * IndexedRowMatrix should be convertable to CoordinateMatrix (conversion added) Tests for both added. Author: Reza Zadeh Closes #4089 from rezazadeh/matutils and squashes the following commits: ec5238b [Reza Zadeh] Array -> Iterator to avoid temp array 3ce0b5d [Reza Zadeh] Array -> Iterator bbc907a [Reza Zadeh] Use 'i' for index, and zipWithIndex cb10ae5 [Reza Zadeh] remove unnecessary import a7ae048 [Reza Zadeh] Missing linear algebra utilities --- .../linalg/distributed/CoordinateMatrix.scala | 5 +++++ .../linalg/distributed/IndexedRowMatrix.scala | 17 +++++++++++++++++ .../distributed/CoordinateMatrixSuite.scala | 5 +++++ .../distributed/IndexedRowMatrixSuite.scala | 8 ++++++++ 4 files changed, 35 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 06d8915f3bfa1..b60559c853a50 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -69,6 +69,11 @@ class CoordinateMatrix( nRows } + /** Transposes this CoordinateMatrix. */ + def transpose(): CoordinateMatrix = { + new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows()) + } + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ def toIndexedRowMatrix(): IndexedRowMatrix = { val nl = numCols() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 181f507516485..c518271f04729 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -75,6 +75,23 @@ class IndexedRowMatrix( new RowMatrix(rows.map(_.vector), 0L, nCols) } + /** + * Converts this matrix to a + * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. 
+ */ + def toCoordinateMatrix(): CoordinateMatrix = { + val entries = rows.flatMap { row => + val rowIndex = row.index + row.vector match { + case SparseVector(size, indices, values) => + Iterator.tabulate(indices.size)(i => MatrixEntry(rowIndex, indices(i), values(i))) + case DenseVector(values) => + Iterator.tabulate(values.size)(i => MatrixEntry(rowIndex, i, values(i))) + } + } + new CoordinateMatrix(entries, numRows(), numCols()) + } + /** * Computes the singular value decomposition of this IndexedRowMatrix. * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index f8709751efce6..80bef814ce50d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -73,6 +73,11 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(mat.toBreeze() === expected) } + test("transpose") { + val transposed = mat.transpose() + assert(mat.toBreeze().t === transposed.toBreeze()) + } + test("toIndexedRowMatrix") { val indexedRowMatrix = mat.toIndexedRowMatrix() val expected = BDM( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 741cd4997b853..b86c2ca5ff136 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -80,6 +80,14 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(rowMat.rows.collect().toSeq === data.map(_.vector).toSeq) } + test("toCoordinateMatrix") { + val idxRowMat = new IndexedRowMatrix(indexedRows) + val coordMat = idxRowMat.toCoordinateMatrix() + assert(coordMat.numRows() === m) + assert(coordMat.numCols() === n) + assert(coordMat.toBreeze() === idxRowMat.toBreeze()) + } + test("multiply a local matrix") { val A = new IndexedRowMatrix(indexedRows) val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) From 7450a992b3b543a373c34fc4444a528954ac4b4a Mon Sep 17 00:00:00 2001 From: "nate.crosswhite" Date: Wed, 21 Jan 2015 10:32:10 -0800 Subject: [PATCH 14/27] [SPARK-4749] [mllib]: Allow initializing KMeans clusters using a seed This implements the functionality for SPARK-4749 and provides units tests in Scala and PySpark Author: nate.crosswhite Author: nxwhite-str Author: Xiangrui Meng Closes #3610 from nxwhite-str/master and squashes the following commits: a2ebbd3 [nxwhite-str] Merge pull request #1 from mengxr/SPARK-4749-kmeans-seed 7668124 [Xiangrui Meng] minor updates f8d5928 [nate.crosswhite] Addressing PR issues 277d367 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 9156a57 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 5d087b4 [nate.crosswhite] Adding KMeans train with seed and Scala unit test 616d111 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 35c1884 [nate.crosswhite] Add kmeans initial seed to pyspark API --- .../mllib/api/python/PythonMLLibAPI.scala | 6 ++- .../spark/mllib/clustering/KMeans.scala | 48 +++++++++++++++---- .../spark/mllib/clustering/KMeansSuite.scala | 21 ++++++++ 
python/pyspark/mllib/clustering.py | 4 +- python/pyspark/mllib/tests.py | 17 ++++++- 5 files changed, 84 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 555da8c7e7ab3..430d763ef7ca7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable { k: Int, maxIterations: Int, runs: Int, - initializationMode: String): KMeansModel = { + initializationMode: String, + seed: java.lang.Long): KMeansModel = { val kMeansAlg = new KMeans() .setK(k) .setMaxIterations(maxIterations) .setRuns(runs) .setInitializationMode(initializationMode) + + if (seed != null) kMeansAlg.setSeed(seed) + try { kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) } finally { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 54c301d3e9e14..6b5c934f015ba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer -import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** @@ -43,13 +43,14 @@ class KMeans private ( private var runs: Int, private var initializationMode: String, private var initializationSteps: Int, - private var epsilon: Double) extends Serializable with Logging { + private var epsilon: Double, + private var seed: Long) extends Serializable with Logging { /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, - * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}. + * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}. */ - def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) + def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) /** Set the number of clusters to create (k). Default: 2. */ def setK(k: Int): this.type = { @@ -112,6 +113,12 @@ class KMeans private ( this } + /** Set the random seed for cluster initialization. */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. 
@@ -255,7 +262,7 @@ class KMeans private ( private def initRandom(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq + val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm) }.toArray) @@ -273,7 +280,7 @@ class KMeans private ( private def initKMeansParallel(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { // Initialize each run's center to a random point - val seed = new XORShiftRandom().nextInt() + val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) @@ -333,7 +340,32 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Array[Double]]` + * @param data training points stored as `RDD[Vector]` + * @param k number of clusters + * @param maxIterations max number of iterations + * @param runs number of parallel runs, defaults to 1. The best model is returned. + * @param initializationMode initialization model, either "random" or "k-means||" (default). + * @param seed random seed value for cluster initialization + */ + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int, + runs: Int, + initializationMode: String, + seed: Long): KMeansModel = { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setRuns(runs) + .setInitializationMode(initializationMode) + .setSeed(seed) + .run(data) + } + + /** + * Trains a k-means model using the given set of parameters. + * + * @param data training points stored as `RDD[Vector]` * @param k number of clusters * @param maxIterations max number of iterations * @param runs number of parallel runs, defaults to 1. The best model is returned. 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 9ebef8466c831..caee5917000aa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { assert(model.clusterCenters.size === 3) } + test("deterministic initialization") { + // Create a large-ish set of points for clustering + val points = List.tabulate(1000)(n => Vectors.dense(n, n)) + val rdd = sc.parallelize(points, 3) + + for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) { + // Create three deterministic models and compare cluster means + val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, + initializationMode = initMode, seed = 42) + val centers1 = model1.clusterCenters + + val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, + initializationMode = initMode, seed = 42) + val centers2 = model2.clusterCenters + + centers1.zip(centers2).foreach { case (c1, c2) => + assert(c1 ~== c2 absTol 1E-14) + } + } + } + test("single cluster with big dataset") { val smallData = Array( Vectors.dense(1.0, 2.0, 6.0), diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e2492eef5bd6a..6b713aa39374e 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -78,10 +78,10 @@ def predict(self, x): class KMeans(object): @classmethod - def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): + def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None): """Train a k-means clustering model.""" model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, - runs, initializationMode) + runs, initializationMode, seed) centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 140c22b5fd4e8..f48e3d6dacb4b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -140,7 +140,7 @@ class ListTests(PySparkTestCase): as NumPy arrays. """ - def test_clustering(self): + def test_kmeans(self): from pyspark.mllib.clustering import KMeans data = [ [0, 1.1], @@ -152,6 +152,21 @@ def test_clustering(self): self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + def test_kmeans_deterministic(self): + from pyspark.mllib.clustering import KMeans + X = range(0, 100, 10) + Y = range(0, 100, 10) + data = [[x, y] for x, y in zip(X, Y)] + clusters1 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + clusters2 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + centers1 = clusters1.centers + centers2 = clusters2.centers + for c1, c2 in zip(centers1, centers2): + # TODO: Allow small numeric difference. 
+ self.assertTrue(array_equal(c1, c2)) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree From 3ee3ab592eee831d759c940eb68231817ad6d083 Mon Sep 17 00:00:00 2001 From: Kenji Kikushima Date: Wed, 21 Jan 2015 12:34:00 -0800 Subject: [PATCH 15/27] [SPARK-5064][GraphX] Add numEdges upperbound validation for R-MAT graph generator to prevent infinite loop I looked into GraphGenerators#chooseCell, and found that chooseCell can't generate more edges than pow(2, (2 * (log2(numVertices)-1))) to make a Power-law graph. (Ex. numVertices:4 upperbound:4, numVertices:8 upperbound:16, numVertices:16 upperbound:64) If we request more edges over the upperbound, rmatGraph fall into infinite loop. So, how about adding an argument validation? Author: Kenji Kikushima Closes #3950 from kj-ki/SPARK-5064 and squashes the following commits: 4ee18c7 [Ankur Dave] Reword error message and add unit test d760bc7 [Kenji Kikushima] Add numEdges upperbound validation for R-MAT graph generator to prevent infinite loop. --- .../org/apache/spark/graphx/util/GraphGenerators.scala | 6 ++++++ .../spark/graphx/util/GraphGeneratorsSuite.scala | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 8a13c74221546..2d6a825b61726 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -133,6 +133,12 @@ object GraphGenerators { // This ensures that the 4 quadrants are the same size at all recursion levels val numVertices = math.round( math.pow(2.0, math.ceil(math.log(requestedNumVertices) / math.log(2.0)))).toInt + val numEdgesUpperBound = + math.pow(2.0, 2 * ((math.log(numVertices) / math.log(2.0)) - 1)).toInt + if (numEdgesUpperBound < numEdges) { + throw new IllegalArgumentException( + s"numEdges must be <= $numEdgesUpperBound but was $numEdges") + } var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 3abefbe52fa8a..8d9c8ddccbb3c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -110,4 +110,14 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { } } + test("SPARK-5064 GraphGenerators.rmatGraph numEdges upper bound") { + withSpark { sc => + val g1 = GraphGenerators.rmatGraph(sc, 4, 4) + assert(g1.edges.count() === 4) + intercept[IllegalArgumentException] { + val g2 = GraphGenerators.rmatGraph(sc, 4, 8) + } + } + } + } From 812d3679f5f97df7b667cbc3365a49866ebc02d5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 21 Jan 2015 12:59:41 -0800 Subject: [PATCH 16/27] [SPARK-5244] [SQL] add coalesce() in sql parser Author: Daoyuan Wang Closes #4040 from adrian-wang/coalesce and squashes the following commits: 0ac8e8f [Daoyuan Wang] add coalesce() in sql parser --- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 ++ .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 0b36d8b9bfce5..388e2f74a0ecb 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -51,6 +51,7 @@ class SqlParser extends AbstractSparkSQLParser { protected val CACHE = Keyword("CACHE") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") + protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") @@ -306,6 +307,7 @@ class SqlParser extends AbstractSparkSQLParser { { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ { case s ~ p ~ l => Substring(s, p, l) } + | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 03b44ca1d6695..64648bad385e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -86,6 +86,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } + test("Add Parser of SQL COALESCE()") { + checkAnswer( + sql("""SELECT COALESCE(1, 2)"""), + 1) + checkAnswer( + sql("SELECT COALESCE(null, null, null)"), + null) + } + test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), From 8361078efae7d79742d6be94cf5a15637ec860dd Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 21 Jan 2015 13:05:56 -0800 Subject: [PATCH 17/27] [SPARK-5009] [SQL] Long keyword support in SQL Parsers * The `SqlLexical.allCaseVersions` will cause `StackOverflowException` if the key word is too long, the patch will fix that by normalizing all of the keywords in `SqlLexical`. * And make a unified SparkSQLParser for sharing the common code. 
Author: Cheng Hao Closes #3926 from chenghao-intel/long_keyword and squashes the following commits: 686660f [Cheng Hao] Support Long Keyword and Refactor the SQLParsers --- .../sql/catalyst/AbstractSparkSQLParser.scala | 59 +++++++++++++----- .../apache/spark/sql/catalyst/SqlParser.scala | 15 +---- .../spark/sql/catalyst/SqlParserSuite.scala | 61 +++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 2 +- .../org/apache/spark/sql/SparkSQLParser.scala | 15 +---- .../org/apache/spark/sql/sources/ddl.scala | 39 +++++------- .../spark/sql/hive/ExtendedHiveQlParser.scala | 16 +---- .../apache/spark/sql/hive/HiveContext.scala | 2 +- 8 files changed, 128 insertions(+), 81 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 93d74adbcc957..366be00473d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -25,15 +25,42 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ +private[sql] object KeywordNormalizer { + def apply(str: String) = str.toLowerCase() +} + private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match { - case Success(plan, _) => plan - case failureOrError => sys.error(failureOrError.toString) + def apply(input: String): LogicalPlan = { + // Initialize the Keywords. + lexical.initialize(reservedWords) + phrase(start)(new lexical.Scanner(input)) match { + case Success(plan, _) => plan + case failureOrError => sys.error(failureOrError.toString) + } } - protected case class Keyword(str: String) + protected case class Keyword(str: String) { + def normalize = KeywordNormalizer(str) + def parser: Parser[String] = normalize + } + + protected implicit def asParser(k: Keyword): Parser[String] = k.parser + + // By default, use Reflection to find the reserved words defined in the sub class. + // NOTICE, Since the Keyword properties defined by sub class, we couldn't call this + // method during the parent class instantiation, because the sub class instance + // isn't created yet. + protected lazy val reservedWords: Seq[String] = + this + .getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].normalize) + + // Set the keywords as empty by default, will change that later. 
+ override val lexical = new SqlLexical protected def start: Parser[LogicalPlan] @@ -52,18 +79,27 @@ private[sql] abstract class AbstractSparkSQLParser } } -class SqlLexical(val keywords: Seq[String]) extends StdLexical { +class SqlLexical extends StdLexical { case class FloatLit(chars: String) extends Token { override def toString = chars } - reserved ++= keywords.flatMap(w => allCaseVersions(w)) + /* This is a work around to support the lazy setting */ + def initialize(keywords: Seq[String]): Unit = { + reserved.clear() + reserved ++= keywords + } delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" ) + protected override def processIdent(name: String) = { + val token = KeywordNormalizer(name) + if (reserved contains token) Keyword(token) else Identifier(name) + } + override lazy val token: Parser[Token] = ( identChar ~ (identChar | digit).* ^^ { case first ~ rest => processIdent((first :: rest).mkString) } @@ -94,14 +130,5 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { | '-' ~ '-' ~ chrExcept(EofCh, '\n').* | '/' ~ '*' ~ failure("unclosed comment") ).* - - /** Generate all variations of upper and lower case of a given string */ - def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s.isEmpty) { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) #::: - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 388e2f74a0ecb..4ca4e05edd460 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -36,9 +36,8 @@ import org.apache.spark.sql.types._ * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ class SqlParser extends AbstractSparkSQLParser { - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") @@ -108,16 +107,6 @@ class SqlParser extends AbstractSparkSQLParser { protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") - // Use reflection to find the reserved words defined in this class. - protected val reservedWords = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { case (ne: NamedExpression, _) => ne diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala new file mode 100644 index 0000000000000..1a0a0e6154ad2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Command +import org.scalatest.FunSuite + +private[sql] case class TestCommand(cmd: String) extends Command + +private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { + protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") + + override protected lazy val start: Parser[LogicalPlan] = set + + private lazy val set: Parser[LogicalPlan] = + EXECUTE ~> ident ^^ { + case fileName => TestCommand(fileName) + } +} + +private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { + protected val EXECUTE = Keyword("EXECUTE") + + override protected lazy val start: Parser[LogicalPlan] = set + + private lazy val set: Parser[LogicalPlan] = + EXECUTE ~> ident ^^ { + case fileName => TestCommand(fileName) + } +} + +class SqlParserSuite extends FunSuite { + + test("test long keyword") { + val parser = new SuperLongKeywordTestParser + assert(TestCommand("NotRealCommand") === parser("ThisIsASuperLongKeyWordTest NotRealCommand")) + } + + test("test case insensitive") { + val parser = new CaseInsensitiveTestParser + assert(TestCommand("NotRealCommand") === parser("EXECUTE NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser("execute NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser("exEcute NotRealCommand")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f23cb18c92d5d..0a22968cc7807 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -107,7 +107,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } protected[sql] def parseSql(sql: String): LogicalPlan = { - ddlParser(sql).getOrElse(sqlParser(sql)) + ddlParser(sql, false).getOrElse(sqlParser(sql)) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index f10ee7b66feb7..f1a4053b79113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql + import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.sql.catalyst.{SqlLexical, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import 
org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand} @@ -61,18 +62,6 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val TABLE = Keyword("TABLE") protected val UNCACHE = Keyword("UNCACHE") - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - - private val reservedWords: Seq[String] = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others private lazy val cache: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 381298caba6f2..171b816a26332 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -18,32 +18,32 @@ package org.apache.spark.sql.sources import scala.language.implicitConversions -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers import org.apache.spark.Logging import org.apache.spark.sql.{SchemaRDD, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ import org.apache.spark.util.Utils + /** * A parser for foreign DDL commands. */ -private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging { - - def apply(input: String): Option[LogicalPlan] = { - phrase(ddl)(new lexical.Scanner(input)) match { - case Success(r, x) => Some(r) - case x => - logDebug(s"Not recognized as DDL: $x") - None +private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { + + def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { + try { + Some(apply(input)) + } catch { + case _ if !exceptionOnError => None + case x: Throwable => throw x } } def parseType(input: String): DataType = { + lexical.initialize(reservedWords) phrase(dataType)(new lexical.Scanner(input)) match { case Success(r, x) => r case x => @@ -51,11 +51,9 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi } } - protected case class Keyword(str: String) - - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object protected val CREATE = Keyword("CREATE") protected val TEMPORARY = Keyword("TEMPORARY") protected val TABLE = Keyword("TABLE") @@ -80,17 +78,10 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected val MAP = Keyword("MAP") protected val STRUCT = Keyword("STRUCT") - // Use reflection to find the reserved words defined in this class. 
- protected val reservedWords = - this.getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - protected lazy val ddl: Parser[LogicalPlan] = createTable + protected def start: Parser[LogicalPlan] = ddl + /** * `CREATE [TEMPORARY] TABLE avroTable * USING org.apache.spark.sql.avro diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index ebf7003ff9e57..3f20c6142e59a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -20,30 +20,20 @@ package org.apache.spark.sql.hive import scala.language.implicitConversions import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical} +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} /** * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. */ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object protected val ADD = Keyword("ADD") protected val DFS = Keyword("DFS") protected val FILE = Keyword("FILE") protected val JAR = Keyword("JAR") - private val reservedWords = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl protected lazy val hiveQl: Parser[LogicalPlan] = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 3e26fe3675768..274f83af5ac03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -70,7 +70,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { if (conf.dialect == "sql") { super.sql(sqlText) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText))) + new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(sqlText))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } From b328ac6c8c489ef9abf850c45db5ad531da18d55 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 21 Jan 2015 14:27:43 -0800 Subject: [PATCH 18/27] Revert "[SPARK-5244] [SQL] add coalesce() in sql parser" This reverts commit 812d3679f5f97df7b667cbc3365a49866ebc02d5. 
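For reference, the production being reverted has the shape COALESCE '(' expression (',' expression)* ')'. The snippet below is a minimal, self-contained sketch of such a rule written against plain scala-parser-combinators rather than Spark's SqlParser; the AST types, the keyword helper, and the object name are illustrative assumptions, not Spark code.

    import scala.util.parsing.combinator.JavaTokenParsers

    sealed trait Expr
    case class Column(name: String) extends Expr
    case class NumberLit(value: Double) extends Expr
    case class Coalesce(children: Seq[Expr]) extends Expr

    object MiniCoalesceParser extends JavaTokenParsers {
      // Match a keyword case-insensitively, roughly how SqlLexical normalizes identifiers.
      private def keyword(k: String): Parser[String] =
        ("""(?i)\Q""" + k + """\E""").r ^^ (_.toUpperCase)

      private def literal: Parser[Expr] = floatingPointNumber ^^ (n => NumberLit(n.toDouble))
      private def column: Parser[Expr] = ident ^^ (Column(_))

      // The reverted rule's shape: COALESCE "(" expr ("," expr)* ")"
      private def coalesce: Parser[Expr] =
        keyword("COALESCE") ~> "(" ~> repsep(expression, ",") <~ ")" ^^ (Coalesce(_))

      def expression: Parser[Expr] = coalesce | literal | column

      def parseExpr(input: String): Expr = parseAll(expression, input) match {
        case Success(result, _) => result
        case failure: NoSuccess => sys.error(failure.msg)
      }
    }

    // MiniCoalesceParser.parseExpr("COALESCE(a, b, 1)")
    //   == Coalesce(List(Column("a"), Column("b"), NumberLit(1.0)))
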
--- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 -- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 --------- 2 files changed, 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 4ca4e05edd460..eaadbe9fd5099 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -50,7 +50,6 @@ class SqlParser extends AbstractSparkSQLParser { protected val CACHE = Keyword("CACHE") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") - protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") @@ -296,7 +295,6 @@ class SqlParser extends AbstractSparkSQLParser { { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ { case s ~ p ~ l => Substring(s, p, l) } - | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 64648bad385e7..03b44ca1d6695 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -86,15 +86,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } - test("Add Parser of SQL COALESCE()") { - checkAnswer( - sql("""SELECT COALESCE(1, 2)"""), - 1) - checkAnswer( - sql("SELECT COALESCE(null, null, null)"), - null) - } - test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), From ba19689fe77b90052b587640c9ff325c5a892c20 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 21 Jan 2015 14:38:10 -0800 Subject: [PATCH 19/27] [SQL] [Minor] Remove deprecated parquet tests This PR removes the deprecated `ParquetQuerySuite`, renamed `ParquetQuerySuite2` to `ParquetQuerySuite`, and refactored changes introduced in #4115 to `ParquetFilterSuite` . It is a follow-up of #3644. Notice that test cases in the old `ParquetQuerySuite` have already been well covered by other test suites introduced in #3644. 
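The ParquetFilterSuite changes below mostly come down to two mechanical moves: the helper now derives the projected attributes from the predicate itself, and the RDD under test is threaded through an implicit parameter so individual assertions no longer repeat it. A minimal, Spark-free sketch of the implicit-fixture half of that pattern follows; all names here are illustrative stand-ins, not Spark APIs.

    object ImplicitFixtureSketch {
      // Stand-in for the data under test (the suite uses a SchemaRDD).
      case class Fixture(rows: Seq[Int])

      // Old style: every call site passes the fixture explicitly.
      def checkExplicit(fixture: Fixture, predicate: Int => Boolean, expected: Seq[Int]): Unit =
        assert(fixture.rows.filter(predicate) == expected)

      // New style: the fixture is picked up implicitly from the enclosing block.
      def checkImplicit(predicate: Int => Boolean, expected: Seq[Int])
                       (implicit fixture: Fixture): Unit =
        assert(fixture.rows.filter(predicate) == expected)

      // Stand-in for withParquetRDD { implicit rdd => ... }.
      def withFixture(rows: Seq[Int])(body: Fixture => Unit): Unit = body(Fixture(rows))

      def main(args: Array[String]): Unit = {
        withFixture(1 to 4) { implicit fixture =>
          checkImplicit(_ < 2, Seq(1))            // fixture no longer repeated per assertion
          checkImplicit(_ >= 3, Seq(3, 4))
          checkExplicit(fixture, _ == 2, Seq(2))  // old style shown for contrast
        }
      }
    }

The classOf[Eq[_]] style in the diff is the related simplification for type parameters: the assertion compares the erased filter class, so the element type can be left as a wildcard.
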
[Review on Reviewable](https://reviewable.io/reviews/apache/spark/4116) Author: Cheng Lian Closes #4116 from liancheng/remove-deprecated-parquet-tests and squashes the following commits: f73b8f9 [Cheng Lian] Removes deprecated Parquet test suite --- .../sql/parquet/ParquetFilterSuite.scala | 373 +++--- .../spark/sql/parquet/ParquetQuerySuite.scala | 1040 +---------------- .../sql/parquet/ParquetQuerySuite2.scala | 88 -- 3 files changed, 212 insertions(+), 1289 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4ad8c472007fc..1e7d3e06fc196 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -21,7 +21,7 @@ import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Literal, Predicate, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} @@ -40,15 +40,16 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} class ParquetFilterSuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext - private def checkFilterPushdown( + private def checkFilterPredicate( rdd: SchemaRDD, - output: Seq[Symbol], predicate: Predicate, filterClass: Class[_ <: FilterPredicate], - checker: (SchemaRDD, Any) => Unit, - expectedResult: Any): Unit = { + checker: (SchemaRDD, Seq[Row]) => Unit, + expected: Seq[Row]): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd.select(output.map(_.attr): _*).where(predicate) + val query = rdd.select(output: _*).where(predicate) val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { case plan: ParquetTableScan => plan.columnPruningPred @@ -58,209 +59,180 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { maybeAnalyzedPredicate.foreach { pred => val maybeFilter = ParquetFilters.createFilter(pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") - maybeFilter.foreach(f => assert(f.getClass === filterClass)) + maybeFilter.foreach { f => + // Doesn't bother checking type parameters here (e.g. 
`Eq[Integer]`) + assert(f.getClass === filterClass) + } } - checker(query, expectedResult) + checker(query, expected) } } - private def checkFilterPushdown1 - (rdd: SchemaRDD, output: Symbol*) - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: => Seq[Row]): Unit = { - checkFilterPushdown(rdd, output, predicate, filterClass, - (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), expectedResult) - } - - private def checkFilterPushdown - (rdd: SchemaRDD, output: Symbol*) - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: Int): Unit = { - checkFilterPushdown(rdd, output, predicate, filterClass, - (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), Seq(Row(expectedResult))) + private def checkFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit rdd: SchemaRDD): Unit = { + checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) } - def checkBinaryFilterPushdown - (rdd: SchemaRDD, output: Symbol*) - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: => Any): Unit = { - def checkBinaryAnswer(rdd: SchemaRDD, result: Any): Unit = { - val actual = rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq - val expected = result match { - case s: Seq[_] => s.map(_.asInstanceOf[Row].getAs[Array[Byte]](0).mkString(",")) - case s => Seq(s.asInstanceOf[Array[Byte]].mkString(",")) - } - assert(actual.sorted === expected.sorted) - } - checkFilterPushdown(rdd, output, predicate, filterClass, checkBinaryAnswer _, expectedResult) + private def checkFilterPredicate[T] + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) + (implicit rdd: SchemaRDD): Unit = { + checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) { - Seq(Row(true), Row(false)) - } + withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - checkFilterPushdown1(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(Seq(Row(true))) - checkFilterPushdown1(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]])(Seq(Row(false))) + checkFilterPredicate('_1 === true, classOf[Eq [_]], true) + checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } } test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[Integer]])(1) - checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [Integer]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [Integer]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[Integer]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, 
classOf[GtEq[Integer]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [Integer]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[Integer]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[Integer]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Long]])(1) - checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Long]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Long]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Long]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Long]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Long]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 
1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Float]])(1) - checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Float]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Float]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Float]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Float]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Float]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => 
Tuple1(Option(i.toDouble)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Double]])(1) - checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Double]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Double]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq[Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Double]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Double]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Double]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { - (1 to 4).map(i => Row.apply(i.toString)) - } - - checkFilterPushdown1(rdd, '_1)('_1 === "1", classOf[Eq[String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) { - (2 to 4).map(i => Row.apply(i.toString)) - } + withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1") + checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 < "2", 
classOf[Lt [_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4") + checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") + checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") + + checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + } + } - checkFilterPushdown1(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown1(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])(Seq(Row("4"))) - - checkFilterPushdown1(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown1(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])(Seq(Row("4"))) - - checkFilterPushdown1(rdd, '_1)(!('_1 < "4"), classOf[Operators.GtEq[java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown1(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])(Seq(Row("3"))) - checkFilterPushdown1(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) { - Seq(Row("1"), Row("4")) + def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit rdd: SchemaRDD): Unit = { + def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } + + checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) + } + + def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) + (implicit rdd: SchemaRDD): Unit = { + checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } test("filter pushdown - binary") { @@ -268,33 +240,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { def b: Array[Byte] = int.toString.getBytes("UTF-8") } - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { rdd => - checkBinaryFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) - checkBinaryFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { - (1 to 4).map(i => Row.apply(i.b)).toSeq - } - - checkBinaryFilterPushdown(rdd, '_1)('_1 === 1.b, classOf[Eq[Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 !== 1.b, classOf[Operators.NotEq[Array[Byte]]]) { - (2 to 4).map(i => Row.apply(i.b)).toSeq - } - - checkBinaryFilterPushdown(rdd, '_1)('_1 < 2.b, classOf[Lt [Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 > 3.b, classOf[Gt [Array[Byte]]])(4.b) - checkBinaryFilterPushdown(rdd, 
'_1)('_1 <= 1.b, classOf[LtEq[Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 >= 4.b, classOf[GtEq[Array[Byte]]])(4.b) - - checkBinaryFilterPushdown(rdd, '_1)(Literal(1.b) === '_1, classOf[Eq [Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(2.b) > '_1, classOf[Lt [Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(3.b) < '_1, classOf[Gt [Array[Byte]]])(4.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(1.b) >= '_1, classOf[LtEq[Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(4.b) <= '_1, classOf[GtEq[Array[Byte]]])(4.b) - - checkBinaryFilterPushdown(rdd, '_1)(!('_1 < 4.b), classOf[Operators.GtEq[Array[Byte]]])(4.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 > 2.b && '_1 < 4.b, classOf[Operators.And])(3.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 < 2.b || '_1 > 3.b, classOf[Operators.Or]) { - Seq(Row(1.b), Row(4.b)) - } + withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkBinaryFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq [_]], 1.b) + checkBinaryFilterPredicate( + '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) + checkBinaryFilterPredicate( + '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 2c5345b1f9148..1263ff818ea19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,1030 +17,72 @@ package org.apache.spark.sql.parquet -import scala.reflect.ClassTag - -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} -import parquet.filter2.predicate.{FilterPredicate, Operators} -import parquet.hadoop.ParquetFileWriter -import parquet.hadoop.util.ContextUtil -import parquet.io.api.Binary - -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{Row => _, _} -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -case class TestRDDEntry(key: Int, value: String) - -case class NullReflectData( - intField: java.lang.Integer, - longField: java.lang.Long, - floatField: 
java.lang.Float, - doubleField: java.lang.Double, - booleanField: java.lang.Boolean) - -case class OptionalReflectData( - intField: Option[Int], - longField: Option[Long], - floatField: Option[Float], - doubleField: Option[Double], - booleanField: Option[Boolean]) - -case class Nested(i: Int, s: String) - -case class Data(array: Seq[Int], nested: Nested) - -case class AllDataTypes( - stringField: String, - intField: Int, - longField: Long, - floatField: Float, - doubleField: Double, - shortField: Short, - byteField: Byte, - booleanField: Boolean) - -case class AllDataTypesWithNonPrimitiveType( - stringField: String, - intField: Int, - longField: Long, - floatField: Float, - doubleField: Double, - shortField: Short, - byteField: Byte, - booleanField: Boolean, - array: Seq[Int], - arrayContainsNull: Seq[Option[Int]], - map: Map[Int, Long], - mapValueContainsNull: Map[Int, Option[Long]], - data: Data) - -case class BinaryData(binaryData: Array[Byte]) - -case class NumericData(i: Int, d: Double) -class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { - TestData // Load test data tables. - - private var testRDD: SchemaRDD = null - private val originalParquetFilterPushdownEnabled = TestSQLContext.conf.parquetFilterPushDown - - override def beforeAll() { - ParquetTestData.writeFile() - ParquetTestData.writeFilterFile() - ParquetTestData.writeNestedFile1() - ParquetTestData.writeNestedFile2() - ParquetTestData.writeNestedFile3() - ParquetTestData.writeNestedFile4() - ParquetTestData.writeGlobFiles() - testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerTempTable("testsource") - parquetFile(ParquetTestData.testFilterDir.toString) - .registerTempTable("testfiltersource") - - setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, "true") - } - - override def afterAll() { - Utils.deleteRecursively(ParquetTestData.testDir) - Utils.deleteRecursively(ParquetTestData.testFilterDir) - Utils.deleteRecursively(ParquetTestData.testNestedDir1) - Utils.deleteRecursively(ParquetTestData.testNestedDir2) - Utils.deleteRecursively(ParquetTestData.testNestedDir3) - Utils.deleteRecursively(ParquetTestData.testNestedDir4) - Utils.deleteRecursively(ParquetTestData.testGlobDir) - // here we should also unregister the table?? - - setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, originalParquetFilterPushdownEnabled.toString) - } +/** + * A test suite that tests various Parquet queries. + */ +class ParquetQuerySuite extends QueryTest with ParquetTest { + val sqlContext = TestSQLContext - test("Read/Write All Types") { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range).map { x => - parquet.AllDataTypes( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0) + test("simple projection") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) } - - data.saveAsParquetFile(tempDir) - - checkAnswer( - parquetFile(tempDir), - data.toSchemaRDD.collect().toSeq) } - test("read/write binary data") { - // Since equality for Array[Byte] is broken we test this separately. 
- val tempDir = getTempFilePath("parquetTest").getCanonicalPath - sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir) - parquetFile(tempDir) - .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) - .collect().toSeq == Seq("test") - } - - ignore("Treat binary as string") { - val oldIsParquetBinaryAsString = TestSQLContext.conf.isParquetBinaryAsString - - // Create the test file. - val file = getTempFilePath("parquet") - val path = file.toString - val range = (0 to 255) - val rowRDD = TestSQLContext.sparkContext.parallelize(range) - .map(i => org.apache.spark.sql.Row(i, s"val_$i".getBytes)) - // We need to ask Parquet to store the String column as a Binary column. - val schema = StructType( - StructField("c1", IntegerType, false) :: - StructField("c2", BinaryType, false) :: Nil) - val schemaRDD1 = applySchema(rowRDD, schema) - schemaRDD1.saveAsParquetFile(path) - checkAnswer( - parquetFile(path).select('c1, 'c2.cast(StringType)), - schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - - setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") - parquetFile(path).printSchema() - checkAnswer( - parquetFile(path), - schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - - - // Set it back. - TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) - } - - test("Compression options for writing to a Parquetfile") { - val defaultParquetCompressionCodec = TestSQLContext.conf.parquetCompressionCodec - import scala.collection.JavaConversions._ - - val file = getTempFilePath("parquet") - val path = file.toString - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - - // test default compression codec - rdd.saveAsParquetFile(path) - var actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test uncompressed parquet file with property value "UNCOMPRESSED" - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "UNCOMPRESSED") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test uncompressed parquet file with property value "none" - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "none") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === "UNCOMPRESSED" :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp 
WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test gzip compression codec - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "gzip") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test snappy compression codec - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "snappy") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // TODO: Lzo requires additional external setup steps so leave it out for now - // ref.: https://github.com/Parquet/parquet-mr/blob/parquet-1.5.0/parquet-hadoop/src/test/java/parquet/hadoop/example/TestInputOutputFormat.java#L169 - - // Set it back. - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, defaultParquetCompressionCodec) - } - - test("Read/Write All Types with non-primitive type") { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range).map { x => - parquet.AllDataTypesWithNonPrimitiveType( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 until x), - (0 until x).map(Option(_).filter(_ % 3 == 0)), - (0 until x).map(i => i -> i.toLong).toMap, - (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), - parquet.Data((0 until x), parquet.Nested(x, s"$x"))) + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT INTO t SELECT * FROM t") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } - data.saveAsParquetFile(tempDir) - - checkAnswer( - parquetFile(tempDir), - data.toSchemaRDD.collect().toSeq) } - test("self-join parquet files") { - val x = ParquetTestData.testData.as('x) - val y = ParquetTestData.testData.as('y) - val query = x.join(y).where("x.myint".attr === "y.myint".attr) - - // Check to make sure that the attributes from either side of the join have unique expression - // ids. 
- query.queryExecution.analyzed.output.filter(_.name == "myint") match { - case Seq(i1, i2) if(i1.exprId == i2.exprId) => - fail(s"Duplicate expression IDs found in query plan: $query") - case Seq(_, _) => // All good + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) } - val result = query.collect() - assert(result.size === 9, "self-join result has incorrect size") - assert(result(0).size === 12, "result row has incorrect size") - result.zipWithIndex.foreach { - case (row, index) => row.toSeq.zipWithIndex.foreach { - case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column") - } - } - } + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output - test("Import of simple Parquet file") { - val result = parquetFile(ParquetTestData.testDir.toString).collect() - assert(result.size === 15) - result.zipWithIndex.foreach { - case (row, index) => { - val checkBoolean = - if (index % 3 == 0) - row(0) == true - else - row(0) == false - assert(checkBoolean === true, s"boolean field value in line $index did not match") - if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") - assert(row(2) === "abc", s"string field value in line $index did not match") - assert(row(3) === (index.toLong << 33), s"long value in line $index did not match") - assert(row(4) === 2.5F, s"float field value in line $index did not match") - assert(row(5) === 4.5D, s"double field value in line $index did not match") + assertResult(4, s"Field count mismatches")(queryOutput.size) + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size } - } - } - test("Projection of simple Parquet file") { - val result = ParquetTestData.testData.select('myboolean, 'mylong).collect() - result.zipWithIndex.foreach { - case (row, index) => { - if (index % 3 == 0) - assert(row(0) === true, s"boolean field value in line $index did not match (every third row)") - else - assert(row(0) === false, s"boolean field value in line $index did not match") - assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match") - assert(row.size === 2, s"number of columns in projection in line $index is incorrect") - } + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) } } - test("Writing metadata from scratch for table CREATE") { - val job = new Job() - val path = new Path(getTempFilePath("testtable").getCanonicalFile.toURI.toString) - val fs: FileSystem = FileSystem.getLocal(ContextUtil.getConfiguration(job)) - ParquetTypesConverter.writeMetaData( - ParquetTestData.testData.output, - path, - TestSQLContext.sparkContext.hadoopConfiguration) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(ContextUtil.getConfiguration(job))) - assert(metaData != null) - ParquetTestData - .testData - .parquetSchema - .checkContains(metaData.getFileMetaData.getSchema) // throws exception if incompatible - metaData - .getFileMetaData - .getSchema - .checkContains(ParquetTestData.testData.parquetSchema) // throws exception if incompatible - fs.delete(path, true) - } - - test("Creating case class RDD table") { - TestSQLContext.sparkContext.parallelize((1 to 100)) 
- .map(i => TestRDDEntry(i, s"val_$i")) - .registerTempTable("tmp") - val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) - var counter = 1 - rdd.foreach { - // '===' does not like string comparison? - row: Row => { - assert(row.getString(1).equals(s"val_$counter"), s"row $counter value ${row.getString(1)} does not match val_$counter") - counter = counter + 1 - } + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) } } - test("Read a parquet file instead of a directory") { - val file = getTempFilePath("parquet") - val path = file.toString - val fsPath = new Path(path) - val fs: FileSystem = fsPath.getFileSystem(TestSQLContext.sparkContext.hadoopConfiguration) - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.coalesce(1).saveAsParquetFile(path) - - val children = fs.listStatus(fsPath).filter(_.getPath.getName.endsWith(".parquet")) - assert(children.length > 0) - val readFile = parquetFile(path + "/" + children(0).getPath.getName) - readFile.registerTempTable("tmpx") - val rdd_copy = sql("SELECT * FROM tmpx").collect() - val rdd_orig = rdd.collect() - for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) } - Utils.deleteRecursively(file) - } - - test("Insert (appending) to same table via Scala API") { - sql("INSERT INTO testsource SELECT * FROM testsource") - val double_rdd = sql("SELECT * FROM testsource").collect() - assert(double_rdd != null) - assert(double_rdd.size === 30) - - // let's restore the original test data - Utils.deleteRecursively(ParquetTestData.testDir) - ParquetTestData.writeFile() - } - - test("save and load case class RDD with nulls as parquet") { - val data = parquet.NullReflectData(null, null, null, null, null) - val rdd = sparkContext.parallelize(data :: Nil) - - val file = getTempFilePath("parquet") - val path = file.toString - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - - val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Row(null, null, null, null, null)) - Utils.deleteRecursively(file) - assert(true) - } - - test("save and load case class RDD with Nones as parquet") { - val data = parquet.OptionalReflectData(None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - - val file = getTempFilePath("parquet") - val path = file.toString - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - - val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Row(null, null, null, null, null)) - Utils.deleteRecursively(file) - assert(true) - } - - test("make RecordFilter for simple predicates") { - def checkFilter[T <: FilterPredicate : ClassTag]( - predicate: Expression, - defined: Boolean = true): Unit = { - val filter = ParquetFilters.createFilter(predicate) - if (defined) { - assert(filter.isDefined) - val tClass = implicitly[ClassTag[T]].runtimeClass - val filterGet = filter.get - assert( - tClass.isInstance(filterGet), - s"$filterGet of type ${filterGet.getClass} is not an instance 
of $tClass") - } else { - assert(filter.isEmpty) - } - } - - checkFilter[Operators.Eq[Integer]]('a.int === 1) - checkFilter[Operators.Eq[Integer]](Literal(1) === 'a.int) - - checkFilter[Operators.Lt[Integer]]('a.int < 4) - checkFilter[Operators.Lt[Integer]](Literal(4) > 'a.int) - checkFilter[Operators.LtEq[Integer]]('a.int <= 4) - checkFilter[Operators.LtEq[Integer]](Literal(4) >= 'a.int) - - checkFilter[Operators.Gt[Integer]]('a.int > 4) - checkFilter[Operators.Gt[Integer]](Literal(4) < 'a.int) - checkFilter[Operators.GtEq[Integer]]('a.int >= 4) - checkFilter[Operators.GtEq[Integer]](Literal(4) <= 'a.int) - - checkFilter[Operators.And]('a.int === 1 && 'a.int < 4) - checkFilter[Operators.Or]('a.int === 1 || 'a.int < 4) - checkFilter[Operators.NotEq[Integer]](!('a.int === 1)) - - checkFilter('a.int > 'b.int, defined = false) - checkFilter(('a.int > 'b.int) && ('a.int > 'b.int), defined = false) - } - - test("test filter by predicate pushdown") { - for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) { - val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") - assert( - query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result1 = query1.collect() - assert(result1.size === 50) - assert(result1(0)(1) === 100) - assert(result1(49)(1) === 149) - val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") - assert( - query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result2 = query2.collect() - assert(result2.size === 50) - if (myval == "myint" || myval == "mylong") { - assert(result2(0)(1) === 151) - assert(result2(49)(1) === 200) - } else { - assert(result2(0)(1) === 150) - assert(result2(49)(1) === 199) - } - } - for(myval <- Seq("myint", "mylong")) { - val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10") - assert( - query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result3 = query3.collect() - assert(result3.size === 20) - assert(result3(0)(1) === 0) - assert(result3(9)(1) === 9) - assert(result3(10)(1) === 191) - assert(result3(19)(1) === 200) - } - for(myval <- Seq("mydouble", "myfloat")) { - val result4 = - if (myval == "mydouble") { - val query4 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10.0") - assert( - query4.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - query4.collect() - } else { - // CASTs are problematic. Here myfloat will be casted to a double and it seems there is - // currently no way to specify float constants in SqlParser? 
- sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect() - } - assert(result4.size === 20) - assert(result4(0)(1) === 0) - assert(result4(9)(1) === 9) - assert(result4(10)(1) === 191) - assert(result4(19)(1) === 200) - } - val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40") - assert( - query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val booleanResult = query5.collect() - assert(booleanResult.size === 10) - for(i <- 0 until 10) { - if (!booleanResult(i).getBoolean(0)) { - fail(s"Boolean value in result row $i not true") - } - if (booleanResult(i).getInt(1) != i * 4) { - fail(s"Int value in result row $i should be ${4*i}") - } - } - val query6 = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"") - assert( - query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val stringResult = query6.collect() - assert(stringResult.size === 1) - assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") - assert(stringResult(0).getInt(1) === 100) - - val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") - assert( - query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val optResult = query7.collect() - assert(optResult.size === 20) - for(i <- 0 until 20) { - if (optResult(i)(7) != i * 2) { - fail(s"optional Int value in result row $i should be ${2*4*i}") - } - } - for(myval <- Seq("myoptint", "myoptlong", "myoptdouble", "myoptfloat")) { - val query8 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") - assert( - query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result8 = query8.collect() - assert(result8.size === 25) - assert(result8(0)(7) === 100) - assert(result8(24)(7) === 148) - val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") - assert( - query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result9 = query9.collect() - assert(result9.size === 25) - if (myval == "myoptint" || myval == "myoptlong") { - assert(result9(0)(7) === 152) - assert(result9(24)(7) === 200) - } else { - assert(result9(0)(7) === 150) - assert(result9(24)(7) === 198) - } - } - val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"") - assert( - query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result10 = query10.collect() - assert(result10.size === 1) - assert(result10(0).getString(8) == "100", "stringvalue incorrect") - assert(result10(0).getInt(7) === 100) - val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40") - assert( - query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result11 = query11.collect() - assert(result11.size === 7) - for(i <- 0 until 6) { - if (!result11(i).getBoolean(6)) { - fail(s"optional Boolean value in result row $i not true") - } - if (result11(i).getInt(7) != i * 6) { - fail(s"optional Int value in result row $i should be ${6*i}") - } - } - - val query12 = sql("SELECT * FROM 
testfiltersource WHERE mystring >= \"50\"") - assert( - query12.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result12 = query12.collect() - assert(result12.size === 54) - assert(result12(0).getString(2) == "6") - assert(result12(4).getString(2) == "50") - assert(result12(53).getString(2) == "99") - - val query13 = sql("SELECT * FROM testfiltersource WHERE mystring > \"50\"") - assert( - query13.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result13 = query13.collect() - assert(result13.size === 53) - assert(result13(0).getString(2) == "6") - assert(result13(4).getString(2) == "51") - assert(result13(52).getString(2) == "99") - - val query14 = sql("SELECT * FROM testfiltersource WHERE mystring <= \"50\"") - assert( - query14.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result14 = query14.collect() - assert(result14.size === 148) - assert(result14(0).getString(2) == "0") - assert(result14(46).getString(2) == "50") - assert(result14(147).getString(2) == "200") - - val query15 = sql("SELECT * FROM testfiltersource WHERE mystring < \"50\"") - assert( - query15.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result15 = query15.collect() - assert(result15.size === 147) - assert(result15(0).getString(2) == "0") - assert(result15(46).getString(2) == "100") - assert(result15(146).getString(2) == "200") } test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") - assert(query.collect().size === 10) - } - - test("Importing nested Parquet file (Addressbook)") { - val result = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD - .collect() - assert(result != null) - assert(result.size === 2) - val first_record = result(0) - val second_record = result(1) - assert(first_record != null) - assert(second_record != null) - assert(first_record.size === 3) - assert(second_record(1) === null) - assert(second_record(2) === null) - assert(second_record(0) === "A. 
Nonymous") - assert(first_record(0) === "Julien Le Dem") - val first_owner_numbers = first_record(1) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - val first_contacts = first_record(2) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(first_owner_numbers != null) - assert(first_owner_numbers(0) === "555 123 4567") - assert(first_owner_numbers(2) === "XXX XXX XXXX") - assert(first_contacts(0) - .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2) - val first_contacts_entry_one = first_contacts(0) - .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy") - assert(first_contacts_entry_one(1) === "555 987 6543") - val first_contacts_entry_two = first_contacts(1) - .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_two(0) === "Chris Aniszczyk") - } - - test("Importing nested Parquet file (nested numbers)") { - val result = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir2.toString) - .toSchemaRDD - .collect() - assert(result.size === 1, "number of top-level rows incorrect") - assert(result(0).size === 5, "number of fields in row incorrect") - assert(result(0)(0) === 1) - assert(result(0)(1) === 7) - val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult1.size === 3) - assert(subresult1(0) === (1.toLong << 32)) - assert(subresult1(1) === (1.toLong << 33)) - assert(subresult1(2) === (1.toLong << 34)) - val subresult2 = result(0)(3) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult2.size === 2) - assert(subresult2(0) === 2.5) - assert(subresult2(1) === false) - val subresult3 = result(0)(4) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult3.size === 2) - assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2) - val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) - } - - test("Simple query on addressbook") { - val data = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD - val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect() - assert(tmp.size === 1) - assert(tmp(0)(0) === "Julien Le Dem") - } - - test("Projection in addressbook") { - val data = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD - data.registerTempTable("data") - val query = sql("SELECT owner, contacts[1].name FROM data") - val tmp = query.collect() - assert(tmp.size === 2) - assert(tmp(0).size === 2) - assert(tmp(0)(0) === "Julien Le Dem") - assert(tmp(0)(1) === "Chris Aniszczyk") - assert(tmp(1)(0) === "A. 
Nonymous") - assert(tmp(1)(1) === null) - } - - test("Simple query on nested int data") { - val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD - data.registerTempTable("data") - val result1 = sql("SELECT entries[0].value FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === 2.5) - val result2 = sql("SELECT entries[0] FROM data").collect() - assert(result2.size === 1) - val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult1.size === 2) - assert(subresult1(0) === 2.5) - assert(subresult1(1) === false) - val result3 = sql("SELECT outerouter FROM data").collect() - val subresult2 = result3(0)(0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(result3(0)(0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) - } - - test("nested structs") { - val data = parquetFile(ParquetTestData.testNestedDir3.toString) - .toSchemaRDD - data.registerTempTable("data") - val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === false) - val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() - assert(result2.size === 1) - assert(result2(0).size === 1) - assert(result2(0)(0) === true) - val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() - assert(result3.size === 1) - assert(result3(0).size === 1) - assert(result3(0)(0) === false) - } - - test("simple map") { - val data = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir4.toString) - .toSchemaRDD - data.registerTempTable("mapTable") - val result1 = sql("SELECT data1 FROM mapTable").collect() - assert(result1.size === 1) - assert(result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key1", 0) === 1) - assert(result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key2", 0) === 2) - val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect() - assert(result2(0)(0) === 1) - } - - test("map with struct values") { - val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD - data.registerTempTable("mapTable") - val result1 = sql("SELECT data2 FROM mapTable").collect() - assert(result1.size === 1) - val entry1 = result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("seven", null) - assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") - val entry2 = result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("eight", null) - assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) - val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() - assert(result2.size === 1) - assert(result2(0)(0) === 42.toLong) - assert(result2(0)(1) === "the answer") - } - - test("Writing out Addressbook and reading it back in") { - // TODO: find out why CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME - // has no 
effect in this test case - val tmpdir = Utils.createTempDir() - Utils.deleteRecursively(tmpdir) - val result = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD - result.saveAsParquetFile(tmpdir.toString) - parquetFile(tmpdir.toString) - .toSchemaRDD - .registerTempTable("tmpcopy") - val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() - assert(tmpdata.size === 2) - assert(tmpdata(0).size === 2) - assert(tmpdata(0)(0) === "Julien Le Dem") - assert(tmpdata(0)(1) === "Chris Aniszczyk") - assert(tmpdata(1)(0) === "A. Nonymous") - assert(tmpdata(1)(1) === null) - Utils.deleteRecursively(tmpdir) - } - - test("Writing out Map and reading it back in") { - val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD - val tmpdir = Utils.createTempDir() - Utils.deleteRecursively(tmpdir) - data.saveAsParquetFile(tmpdir.toString) - parquetFile(tmpdir.toString) - .toSchemaRDD - .registerTempTable("tmpmapcopy") - val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() - assert(result1.size === 1) - assert(result1(0)(0) === 2) - val result2 = sql("SELECT data2 FROM tmpmapcopy").collect() - assert(result2.size === 1) - val entry1 = result2(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("seven", null) - assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") - val entry2 = result2(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("eight", null) - assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) - val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() - assert(result3.size === 1) - assert(result3(0)(0) === 42.toLong) - assert(result3(0)(1) === "the answer") - Utils.deleteRecursively(tmpdir) - } - - test("read/write fixed-length decimals") { - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val data = sparkContext.parallelize(0 to 1000) - .map(i => NumericData(i, i / 100.0)) - .select('i, 'd cast DecimalType(precision, scale)) - data.saveAsParquetFile(tempDir) - checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) - } - - // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val data = sparkContext.parallelize(0 to 1000) - .map(i => NumericData(i, i / 100.0)) - .select('i, 'd cast DecimalType(19, 10)) - data.saveAsParquetFile(tempDir) - checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) - } - - // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val data = sparkContext.parallelize(0 to 1000) - .map(i => NumericData(i, i / 100.0)) - .select('i, 'd cast DecimalType.Unlimited) - data.saveAsParquetFile(tempDir) - checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) - } - } - - def checkFilter(predicate: Predicate, filterClass: Class[_ <: FilterPredicate]): Unit = { - val filter = ParquetFilters.createFilter(predicate) - assert(filter.isDefined) - assert(filter.get.getClass == filterClass) - } - - test("Pushdown IsNull predicate") { - checkFilter('a.int.isNull, classOf[Operators.Eq[Integer]]) - checkFilter('a.long.isNull, classOf[Operators.Eq[java.lang.Long]]) - 
checkFilter('a.float.isNull, classOf[Operators.Eq[java.lang.Float]]) - checkFilter('a.double.isNull, classOf[Operators.Eq[java.lang.Double]]) - checkFilter('a.string.isNull, classOf[Operators.Eq[Binary]]) - checkFilter('a.binary.isNull, classOf[Operators.Eq[Binary]]) - } - - test("Pushdown IsNotNull predicate") { - checkFilter('a.int.isNotNull, classOf[Operators.NotEq[Integer]]) - checkFilter('a.long.isNotNull, classOf[Operators.NotEq[java.lang.Long]]) - checkFilter('a.float.isNotNull, classOf[Operators.NotEq[java.lang.Float]]) - checkFilter('a.double.isNotNull, classOf[Operators.NotEq[java.lang.Double]]) - checkFilter('a.string.isNotNull, classOf[Operators.NotEq[Binary]]) - checkFilter('a.binary.isNotNull, classOf[Operators.NotEq[Binary]]) - } - - test("Pushdown EqualTo predicate") { - checkFilter('a.int === 0, classOf[Operators.Eq[Integer]]) - checkFilter('a.long === 0.toLong, classOf[Operators.Eq[java.lang.Long]]) - checkFilter('a.float === 0.toFloat, classOf[Operators.Eq[java.lang.Float]]) - checkFilter('a.double === 0.toDouble, classOf[Operators.Eq[java.lang.Double]]) - checkFilter('a.string === "foo", classOf[Operators.Eq[Binary]]) - checkFilter('a.binary === "foo".getBytes, classOf[Operators.Eq[Binary]]) - } - - test("Pushdown Not(EqualTo) predicate") { - checkFilter(!('a.int === 0), classOf[Operators.NotEq[Integer]]) - checkFilter(!('a.long === 0.toLong), classOf[Operators.NotEq[java.lang.Long]]) - checkFilter(!('a.float === 0.toFloat), classOf[Operators.NotEq[java.lang.Float]]) - checkFilter(!('a.double === 0.toDouble), classOf[Operators.NotEq[java.lang.Double]]) - checkFilter(!('a.string === "foo"), classOf[Operators.NotEq[Binary]]) - checkFilter(!('a.binary === "foo".getBytes), classOf[Operators.NotEq[Binary]]) - } - - test("Pushdown LessThan predicate") { - checkFilter('a.int < 0, classOf[Operators.Lt[Integer]]) - checkFilter('a.long < 0.toLong, classOf[Operators.Lt[java.lang.Long]]) - checkFilter('a.float < 0.toFloat, classOf[Operators.Lt[java.lang.Float]]) - checkFilter('a.double < 0.toDouble, classOf[Operators.Lt[java.lang.Double]]) - checkFilter('a.string < "foo", classOf[Operators.Lt[Binary]]) - checkFilter('a.binary < "foo".getBytes, classOf[Operators.Lt[Binary]]) - } - - test("Pushdown LessThanOrEqual predicate") { - checkFilter('a.int <= 0, classOf[Operators.LtEq[Integer]]) - checkFilter('a.long <= 0.toLong, classOf[Operators.LtEq[java.lang.Long]]) - checkFilter('a.float <= 0.toFloat, classOf[Operators.LtEq[java.lang.Float]]) - checkFilter('a.double <= 0.toDouble, classOf[Operators.LtEq[java.lang.Double]]) - checkFilter('a.string <= "foo", classOf[Operators.LtEq[Binary]]) - checkFilter('a.binary <= "foo".getBytes, classOf[Operators.LtEq[Binary]]) - } - - test("Pushdown GreaterThan predicate") { - checkFilter('a.int > 0, classOf[Operators.Gt[Integer]]) - checkFilter('a.long > 0.toLong, classOf[Operators.Gt[java.lang.Long]]) - checkFilter('a.float > 0.toFloat, classOf[Operators.Gt[java.lang.Float]]) - checkFilter('a.double > 0.toDouble, classOf[Operators.Gt[java.lang.Double]]) - checkFilter('a.string > "foo", classOf[Operators.Gt[Binary]]) - checkFilter('a.binary > "foo".getBytes, classOf[Operators.Gt[Binary]]) - } - - test("Pushdown GreaterThanOrEqual predicate") { - checkFilter('a.int >= 0, classOf[Operators.GtEq[Integer]]) - checkFilter('a.long >= 0.toLong, classOf[Operators.GtEq[java.lang.Long]]) - checkFilter('a.float >= 0.toFloat, classOf[Operators.GtEq[java.lang.Float]]) - checkFilter('a.double >= 0.toDouble, classOf[Operators.GtEq[java.lang.Double]]) - 
checkFilter('a.string >= "foo", classOf[Operators.GtEq[Binary]]) - checkFilter('a.binary >= "foo".getBytes, classOf[Operators.GtEq[Binary]]) - } - - test("Comparison with null should not be pushed down") { - val predicates = Seq( - 'a.int === null, - !('a.int === null), - - Literal(null) === 'a.int, - !(Literal(null) === 'a.int), - - 'a.int < null, - 'a.int <= null, - 'a.int > null, - 'a.int >= null, - - Literal(null) < 'a.int, - Literal(null) <= 'a.int, - Literal(null) > 'a.int, - Literal(null) >= 'a.int - ) - - predicates.foreach { p => - assert( - ParquetFilters.createFilter(p).isEmpty, - "Comparison predicate with null shouldn't be pushed down") - } - } - - test("Import of simple Parquet files using glob wildcard pattern") { - val testGlobDir = ParquetTestData.testGlobDir.toString - val globPatterns = Array(testGlobDir + "/*/*", testGlobDir + "/spark-*/*", testGlobDir + "/?pa?k-*/*") - globPatterns.foreach { path => - val result = parquetFile(path).collect() - assert(result.size === 45) - result.zipWithIndex.foreach { - case (row, index) => { - val checkBoolean = - if ((index % 15) % 3 == 0) - row(0) == true - else - row(0) == false - assert(checkBoolean === true, s"boolean field value in line $index did not match") - if ((index % 15) % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") - assert(row(2) === "abc", s"string field value in line $index did not match") - assert(row(3) === ((index.toLong % 15) << 33), s"long value in line $index did not match") - assert(row(4) === 2.5F, s"float field value in line $index did not match") - assert(row(5) === 4.5D, s"double field value in line $index did not match") - } - } + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala deleted file mode 100644 index 7b3f8c22af2db..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ - -/** - * A test suite that tests various Parquet queries. 
- */ -class ParquetQuerySuite2 extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext - - test("simple projection") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) - } - } - - test("appending") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT INTO t SELECT * FROM t") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) - } - } - - test("self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } - - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output - - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } - - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) - } - } - - test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) - } - } - - test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) - } - } - - test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) - } - } -} From 3be2a887bf6107b6398e472872b22175ea4ae1f7 Mon Sep 17 00:00:00 2001 From: wangfei Date: Wed, 21 Jan 2015 15:27:42 -0800 Subject: [PATCH 20/27] [SPARK-4984][CORE][WEBUI] Adding a pop-up containing the full job description when it is very long In some case the job description will be very long, such as a long sql. refer to #3718 This PR add a pop-up for job description when it is long. 
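A hedged sketch of the markup shape this introduces (the helper name and the title attribute here are assumptions for illustration; only the description-input class comes from the webui.css hunk below). The CSS truncates the visible text with an ellipsis, while the element's title keeps the full description available on hover.

    // Hedged sketch only: a truncating span whose title carries the full description.
    import scala.xml.Elem

    def descriptionCell(description: String): Elem =
      <span class="description-input" title={description}>{description}</span>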
![image](https://cloud.githubusercontent.com/assets/7018048/5847400/c757cbbc-a207-11e4-891f-528821c2e68d.png) ![image](https://cloud.githubusercontent.com/assets/7018048/5847409/d434b2b4-a207-11e4-8813-03a74b43d766.png) Author: wangfei Closes #3819 from scwf/popup-descrip-ui and squashes the following commits: ba02b83 [wangfei] address comments a7c5e7b [wangfei] spot that it's been truncated fbf6162 [wangfei] Merge branch 'master' into popup-descrip-ui 0bca96d [wangfei] remove no use val 4b55c3b [wangfei] fix style issue 353c6f4 [wangfei] pop up the description of job with a styled read-only text form field --- .../main/resources/org/apache/spark/ui/static/webui.css | 8 ++++++++ .../main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 2 +- .../main/scala/org/apache/spark/ui/jobs/StageTable.scala | 3 +-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index f02b035a980b1..a1f7133f897ee 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -121,6 +121,14 @@ pre { border: none; } +.description-input { + overflow: hidden; + text-overflow: ellipsis; + width: 100%; + white-space: nowrap; + display: block; +} + .stacktrace-details { max-height: 300px; overflow-y: auto; diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 81212708ba524..045c69da06feb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -64,7 +64,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} -
{lastStageDescription}
+ {lastStageDescription} {lastStageName} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index e7d6244dcd679..703d43f9c640d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -112,9 +112,8 @@ private[ui] class StageTableBase( stageData <- listener.stageIdToData.get((s.stageId, s.attemptId)) desc <- stageData.description } yield { -
{desc}
+ {desc} } -
{stageDesc.getOrElse("")} {killLink} {nameLink} {details}
} From 9bad062268676aaa66dcbddd1e0ab7f2d7742425 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Jan 2015 16:51:42 -0800 Subject: [PATCH 21/27] [SPARK-5355] make SparkConf thread-safe The SparkConf is not thread-safe, but is accessed by many threads. The getAll() could return parts of the configs if another thread is access it. This PR changes SparkConf.settings to a thread-safe TrieMap. Author: Davies Liu Closes #4143 from davies/safe-conf and squashes the following commits: f8fa1cf [Davies Liu] change to TrieMap a1d769a [Davies Liu] make SparkConf thread-safe --- core/src/main/scala/org/apache/spark/SparkConf.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a0ce107f43b16..f9d4aa4240e9d 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,6 +18,7 @@ package org.apache.spark import scala.collection.JavaConverters._ +import scala.collection.concurrent.TrieMap import scala.collection.mutable.{HashMap, LinkedHashSet} import org.apache.spark.serializer.KryoSerializer @@ -46,7 +47,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new HashMap[String, String]() + private[spark] val settings = new TrieMap[String, String]() if (loadDefaults) { // Load any spark.* system properties @@ -177,7 +178,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.clone().toArray + def getAll: Array[(String, String)] = settings.toArray /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { From 27bccc5ea9742ca166f6026033e7771c8a90cc0a Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 21 Jan 2015 17:34:18 -0800 Subject: [PATCH 22/27] [SPARK-5202] [SQL] Add hql variable substitution support https://cwiki.apache.org/confluence/display/Hive/LanguageManual+VariableSubstitution This is a block issue for the CLI user, it impacts the existed hql scripts from Hive. 
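A hedged usage sketch of what this enables (the application setup and object name are assumptions; the src table and the hive.variable.substitute switch mirror the test added below):

    // Hedged sketch: Hive variable substitution through HiveContext.sql after this patch.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext

    object HqlSubstitutionSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("hql-substitution").setMaster("local[2]"))
        val hive = new HiveContext(sc)

        // Define a Hive variable, then reference it with ${hiveconf:...} in a later query.
        hive.sql("set tbl=src")
        hive.sql("SELECT key, value FROM ${hiveconf:tbl} LIMIT 10").collect()

        // Substitution honours hive.variable.substitute, so it can be switched off again.
        hive.sql("set hive.variable.substitute=false")

        sc.stop()
      }
    }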
Author: Cheng Hao Closes #4003 from chenghao-intel/substitution and squashes the following commits: bb41fd6 [Cheng Hao] revert the removed the implicit conversion af7c31a [Cheng Hao] add hql variable substitution support --- .../apache/spark/sql/hive/HiveContext.scala | 6 ++++-- .../sql/hive/execution/SQLQuerySuite.scala | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 274f83af5ac03..9d2cfd8e0d669 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} @@ -66,11 +67,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { new this.QueryExecution { val logical = plan } override def sql(sqlText: String): SchemaRDD = { + val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { - super.sql(sqlText) + super.sql(substituted) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(sqlText))) + new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f6bf2dbb5d6e4..7f9f1ac7cd80d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -104,6 +104,24 @@ class SQLQuerySuite extends QueryTest { ) } + test("command substitution") { + sql("set tbl=src") + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + + sql("set hive.variable.substitute=false") // disable the substitution + sql("set tbl2=src") + intercept[Exception] { + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect() + } + + sql("set hive.variable.substitute=true") // enable the substitution + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + } + test("ordering not in select") { checkAnswer( sql("SELECT key FROM src ORDER BY value"), From ca7910d6dd7693be2a675a0d6a6fcc9eb0aaeb5d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 21 Jan 2015 21:20:31 -0800 Subject: [PATCH 23/27] [SPARK-3424][MLLIB] cache point distances during k-means|| init This PR ports the following feature implemented in #2634 by derrickburns: * During k-means|| initialization, we should cache costs (squared distances) previously computed. 
It also contains the following optimization: * aggregate sumCosts directly * ran multiple (#runs) k-means++ in parallel I compared the performance locally on mnist-digit. Before this patch: ![before](https://cloud.githubusercontent.com/assets/829644/5845647/93080862-a172-11e4-9a35-044ec711afc4.png) with this patch: ![after](https://cloud.githubusercontent.com/assets/829644/5845653/a47c29e8-a172-11e4-8e9f-08db57fe3502.png) It is clear that each k-means|| iteration takes about the same amount of time with this patch. Authors: Derrick Burns Xiangrui Meng Closes #4144 from mengxr/SPARK-3424-kmeans-parallel and squashes the following commits: 0a875ec [Xiangrui Meng] address comments 4341bb8 [Xiangrui Meng] do not re-compute point distances during k-means|| --- .../spark/mllib/clustering/KMeans.scala | 65 ++++++++++++++----- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 6b5c934f015ba..fc46da3a93425 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -279,45 +279,80 @@ class KMeans private ( */ private def initKMeansParallel(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { - // Initialize each run's center to a random point + // Initialize empty centers and point costs. + val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) + var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + + // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq - val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + + /** Merges new centers to centers. */ + def mergeNewCenters(): Unit = { + var r = 0 + while (r < runs) { + centers(r) ++= newCenters(r) + newCenters(r).clear() + r += 1 + } + } // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's current centers + // to their squared distance from that run's centers. Note that only distances between points + // and new centers are computed in each iteration. 
var step = 0 while (step < initializationSteps) { - val bcCenters = data.context.broadcast(centers) - val sumCosts = data.flatMap { point => - (0 until runs).map { r => - (r, KMeans.pointCost(bcCenters.value(r), point)) - } - }.reduceByKey(_ + _).collectAsMap() - val chosen = data.mapPartitionsWithIndex { (index, points) => + val bcNewCenters = data.context.broadcast(newCenters) + val preCosts = costs + costs = data.zip(preCosts).map { case (point, cost) => + Vectors.dense( + Array.tabulate(runs) { r => + math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) + }) + }.cache() + val sumCosts = costs + .aggregate(Vectors.zeros(runs))( + seqOp = (s, v) => { + // s += v + axpy(1.0, v, s) + s + }, + combOp = (s0, s1) => { + // s0 += s1 + axpy(1.0, s1, s0) + s0 + } + ) + preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) - points.flatMap { p => + pointsWithCosts.flatMap { case (p, c) => (0 until runs).filter { r => - rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r) + rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) }.map((_, p)) } }.collect() + mergeNewCenters() chosen.foreach { case (r, p) => - centers(r) += p.toDense + newCenters(r) += p.toDense } step += 1 } + mergeNewCenters() + costs.unpersist(blocking = false) + // Finally, we might have a set of more than k candidate centers for each run; weigh each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick just k of them val bcCenters = data.context.broadcast(centers) val weightMap = data.flatMap { p => - (0 until runs).map { r => + Iterator.tabulate(runs) { r => ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() - val finalCenters = (0 until runs).map { r => + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) From fcb3e1862ffe784f39bde467e8d24c1b7ed3afbb Mon Sep 17 00:00:00 2001 From: Basin Date: Wed, 21 Jan 2015 23:06:34 -0800 Subject: [PATCH 24/27] [SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.Classification or Algo.Regression JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5317 When setting the BoostingStrategy.defaultParams("Classification"), It's more straightforward to set it with the Enumeration Algo.Classification, just like BoostingStragety.defaultParams(Algo.Classification). I overload the method BoostingStragety.defaultParams(). Author: Basin Closes #4103 from Peishen-Jia/stragetyAlgo and squashes the following commits: 87bab1c [Basin] Docs and Code documentations updated. 3b72875 [Basin] defaultParams(algoStr: String) call defaultParams(algo: Algo). 7c1e6ee [Basin] Doc of Java updated. algo -> algoStr instead. d5c8a2e [Basin] Merge branch 'stragetyAlgo' of github.com:Peishen-Jia/spark into stragetyAlgo 65f96ce [Basin] mllib-ensembles doc modified. e04a5aa [Basin] boostingstrategy.defaultParam string algo to enumeration. 68cf544 [Basin] mllib-ensembles doc modified. a4aea51 [Basin] boostingstrategy.defaultParam string algo to enumeration. 
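A hedged usage sketch of the overload this adds (object and value names are illustrative; both call forms appear in the diff below):

    import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

    object BoostingDefaultsSketch {
      // The enumeration form is now the primary entry point...
      val byEnum = BoostingStrategy.defaultParams(Algo.Classification)
      // ...and the existing string form still works, delegating to the enumeration overload.
      val byString = BoostingStrategy.defaultParams("Regression")
    }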
--- .../tree/configuration/BoostingStrategy.scala | 25 +++++++++++++------ .../mllib/tree/configuration/Strategy.scala | 14 ++++++++--- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c65a9..ed8e6a796f8c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -68,6 +68,15 @@ case class BoostingStrategy( @Experimental object BoostingStrategy { + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: "Classification" or "Regression" + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + defaultParams(Algo.fromString(algo)) + } + /** * Returns default configuration for the boosting algorithm * @param algo Learning goal. Supported: @@ -75,15 +84,15 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) - treeStrategy.maxDepth = 3 + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 algo match { - case "Classification" => - treeStrategy.numClasses = 2 - new BoostingStrategy(treeStrategy, LogLoss) - case "Regression" => - new BoostingStrategy(treeStrategy, SquaredError) + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d5cd89ab94e81..972959885f396 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -173,11 +173,19 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { - case "Classification" => + def defaultStrategy(algo: String): Strategy = { + defaultStategy(Algo.fromString(algo)) + } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) - case "Regression" => + case Algo.Regression => new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } From 3027f06b4127ab23a43c5ce8cebf721e3b6766e5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Jan 2015 23:41:44 -0800 Subject: [PATCH 25/27] [SPARK-5147][Streaming] Delete the received data WAL log periodically This is a refactored fix based on jerryshao 's PR #4037 This enabled deletion of old WAL files containing the received block data. 
Improvements over #4037 - Respecting the rememberDuration of all receiver streams. In #4037, if there were two receiver streams with multiple remember durations, the deletion would have delete based on the shortest remember duration, thus deleting data prematurely for the receiver stream with longer remember duration. - Added unit test to test creation of receiver WAL, automatic deletion, and respecting of remember duration. jerryshao I am going to merge this ASAP to make it 1.2.1 Thanks for the initial draft of this PR. Made my job much easier. Author: Tathagata Das Author: jerryshao Closes #4149 from tdas/SPARK-5147 and squashes the following commits: 730798b [Tathagata Das] Added comments. c4cf067 [Tathagata Das] Minor fixes 2579b27 [Tathagata Das] Refactored the fix to make sure that the cleanup respects the remember duration of all the receiver streams 2736fd1 [jerryshao] Delete the old WAL log periodically --- .../apache/spark/streaming/DStreamGraph.scala | 8 + .../dstream/ReceiverInputDStream.scala | 11 -- .../streaming/receiver/ReceiverMessage.scala | 5 +- .../receiver/ReceiverSupervisorImpl.scala | 9 + .../streaming/scheduler/JobGenerator.scala | 11 +- .../scheduler/ReceivedBlockTracker.scala | 1 - .../streaming/scheduler/ReceiverTracker.scala | 18 +- .../spark/streaming/util/HdfsUtils.scala | 2 +- .../spark/streaming/ReceiverSuite.scala | 157 ++++++++++++++---- 9 files changed, 172 insertions(+), 50 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index e59c24adb84af..0e285d6088ec1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } + /** + * Get the maximum remember duration across all the input streams. This is a conservative but + * safe remember duration which can be used to perform cleanup operations. + */ + def getMaxInputStreamRememberDuration(): Duration = { + inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { logDebug("DStreamGraph.writeObject used") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index afd3c4bc4c4fe..8be04314c4285 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } - - /** - * Clear metadata that are older than `rememberDuration` of this DStream. - * This is an internal method that should not be called directly. This - * implementation overrides the default implementation to clear received - * block information. 
- */ - private[streaming] override def clearMetadata(time: Time) { - super.clearMetadata(time) - ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) - } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index ab9fa192191aa..7bf3c33319491 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -17,7 +17,10 @@ package org.apache.spark.streaming.receiver -/** Messages sent to the NetworkReceiver. */ +import org.apache.spark.streaming.Time + +/** Messages sent to the Receiver. */ private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage +private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index d7229c2b96d0b..716cf2c7f32fc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl( case StopReceiver => logInfo("Received stop signal") stop("Stopped by driver", None) + case CleanupOldBlocks(threshTime) => + logDebug("Received delete old batch signal") + cleanupOldBlocks(threshTime) } def ref = self @@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl( /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) + + private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = { + logDebug(s"Cleaning up blocks older then $cleanupThreshTime") + receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 39b66e1130768..d86f852aba97e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -238,13 +238,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { eventActor ! DoCheckpoint(time) } else { + // If checkpointing is not enabled, then delete metadata information about + // received blocks (block data not saved in any case). Otherwise, wait for + // checkpointing of this batch to complete. 
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } } @@ -252,6 +256,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream checkpoint data for the given `time`. */ private def clearCheckpointData(time: Time) { ssc.graph.clearCheckpointData(time) + + // All the checkpoint information about which batches have been processed, etc have + // been saved to checkpoints, so its safe to delete block metadata and data WAL files + val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index c3d9d7b6813d3..ef23b5c79f2e1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -150,7 +150,6 @@ private[streaming] class ReceivedBlockTracker( writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) - log } /** Stop the block tracker. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 8dbb42a86e3bd..4f998869731ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -24,9 +24,8 @@ import scala.language.existentials import akka.actor._ import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - /** Clean up metadata older than the given threshold time */ - def cleanupOldMetadata(cleanupThreshTime: Time) { + /** + * Clean up the data and metadata of blocks and batches that are strictly + * older than the threshold time. Note that this does not + */ + def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) { + // Clean up old block and batch metadata receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) + + // Signal the receivers to delete old block data + if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + logInfo(s"Cleanup old received batch data: $cleanupThreshTime") + receiverInfo.values.flatMap { info => Option(info.actor) } + .foreach { _ ! 
CleanupOldBlocks(cleanupThreshTime) } + } } /** Register a receiver */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 27a28bab83ed5..858ba3c9eb4e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -63,7 +63,7 @@ private[streaming] object HdfsUtils { } def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = { - // For local file systems, return the raw loca file system, such calls to flush() + // For local file systems, return the raw local file system, such calls to flush() // actually flushes the stream. val fs = path.getFileSystem(conf) fs match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e26c0c6859e57..e8c34a9ee40b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -17,21 +17,26 @@ package org.apache.spark.streaming +import java.io.File import java.nio.ByteBuffer import java.util.concurrent.Semaphore +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkConf -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor} -import org.scalatest.FunSuite +import com.google.common.io.Files import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver._ +import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ + /** Testsuite for testing the network receiver behavior */ -class ReceiverSuite extends FunSuite with Timeouts { +class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("receiver life cycle") { @@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts { val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") - println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes) assert( // the first and last block may be incomplete, so we slice them out recordedBlocks.drop(1).dropRight(1).forall { block => @@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts { ) } - /** - * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. + * Test whether write ahead logs are generated by received, + * and automatically cleaned up. The clean up must be aware of the + * remember duration of the input streams. E.g., input streams on which window() + * has been applied must remember the data for longer, and hence corresponding + * WALs should be cleaned later. 
*/ - class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - @volatile var otherThread: Thread = null - @volatile var receiving = false - @volatile var onStartCalled = false - @volatile var onStopCalled = false - - def onStart() { - otherThread = new Thread() { - override def run() { - receiving = true - while(!isStopped()) { - Thread.sleep(10) - } + test("write ahead log - generating and cleaning") { + val sparkConf = new SparkConf() + .setMaster("local[4]") // must be at least 3 as we are going to start 2 receivers + .setAppName(framework) + .set("spark.ui.enabled", "true") + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val batchDuration = Milliseconds(500) + val tempDirectory = Files.createTempDir() + val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) + val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1)) + val allLogFiles1 = new mutable.HashSet[String]() + val allLogFiles2 = new mutable.HashSet[String]() + logInfo("Temp checkpoint directory = " + tempDirectory) + + def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = { + (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2)) + } + + def getCurrentLogFiles(logDirectory: File): Seq[String] = { + try { + if (logDirectory.exists()) { + logDirectory1.listFiles().filter { _.getName.startsWith("log") }.map { _.toString } + } else { + Seq.empty } + } catch { + case e: Exception => + Seq.empty } - onStartCalled = true - otherThread.start() - } - def onStop() { - onStopCalled = true - otherThread.join() + def printLogFiles(message: String, files: Seq[String]) { + logInfo(s"$message (${files.size} files):\n" + files.mkString("\n")) } - def reset() { - receiving = false - onStartCalled = false - onStopCalled = false + withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc => + tempDirectory.deleteOnExit() + val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiverStream1 = ssc.receiverStream(receiver1) + val receiverStream2 = ssc.receiverStream(receiver2) + receiverStream1.register() + receiverStream2.window(batchDuration * 6).register() // 3 second window + ssc.checkpoint(tempDirectory.getAbsolutePath()) + ssc.start() + + // Run until sufficient WAL files have been generated and + // the first WAL files has been deleted + eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) { + val (logFiles1, logFiles2) = getBothCurrentLogFiles() + allLogFiles1 ++= logFiles1 + allLogFiles2 ++= logFiles2 + if (allLogFiles1.size > 0) { + assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head)) + } + if (allLogFiles2.size > 0) { + assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head)) + } + assert(allLogFiles1.size >= 7) + assert(allLogFiles2.size >= 7) + } + ssc.stop(stopSparkContext = true, stopGracefully = true) + + val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted + val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted + val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles() + + printLogFiles("Receiver 0: all", sortedAllLogFiles1) + printLogFiles("Receiver 0: left", leftLogFiles1) + printLogFiles("Receiver 1: all", sortedAllLogFiles2) + printLogFiles("Receiver 1: left", leftLogFiles2) + + // Verify that necessary latest log files are not deleted + // receiverStream1 needs to retain just the last batch = 1 log 
file + // receiverStream2 needs to retain 3 seconds (3-seconds window) = 3 log files + assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains)) + assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains)) } } @@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts { } } +/** + * An implementation of Receiver that is used for testing a receiver's life cycle. + */ +class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + @volatile var otherThread: Thread = null + @volatile var receiving = false + @volatile var onStartCalled = false + @volatile var onStopCalled = false + + def onStart() { + otherThread = new Thread() { + override def run() { + receiving = true + var count = 0 + while(!isStopped()) { + if (sendData) { + store(count) + count += 1 + } + Thread.sleep(10) + } + } + } + onStartCalled = true + otherThread.start() + } + + def onStop() { + onStopCalled = true + otherThread.join() + } + + def reset() { + receiving = false + onStartCalled = false + onStopCalled = false + } +} + From 246111d179a2f3f6b97a5c2b121d8ddbfd1c9aad Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Jan 2015 08:16:35 -0800 Subject: [PATCH 26/27] [SPARK-5365][MLlib] Refactor KMeans to reduce redundant data If a point is selected as new centers for many runs, it would collect many redundant data. This pr refactors it. Author: Liang-Chi Hsieh Closes #4159 from viirya/small_refactor_kmeans and squashes the following commits: 25487e6 [Liang-Chi Hsieh] Refactor codes to reduce redundant data. --- .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fc46da3a93425..11633e8242313 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -328,14 +328,15 @@ class KMeans private ( val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) pointsWithCosts.flatMap { case (p, c) => - (0 until runs).filter { r => + val rs = (0 until runs).filter { r => rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) - }.map((_, p)) + } + if (rs.length > 0) Some(p, rs) else None } }.collect() mergeNewCenters() - chosen.foreach { case (r, p) => - newCenters(r) += p.toDense + chosen.foreach { case (p, rs) => + rs.foreach(newCenters(_) += p.toDense) } step += 1 } From 820ce03597350257abe0c5c96435c555038e3e6c Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 22 Jan 2015 13:49:35 -0600 Subject: [PATCH 27/27] SPARK-5370. [YARN] Remove some unnecessary synchronization in YarnAlloca... ...tor Author: Sandy Ryza Closes #4164 from sryza/sandy-spark-5370 and squashes the following commits: 0c8d736 [Sandy Ryza] SPARK-5370. 
[YARN] Remove some unnecessary synchronization in YarnAllocator --- .../spark/deploy/yarn/YarnAllocator.scala | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4c35b60c57df3..d00f29665a58f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -60,7 +60,6 @@ private[yarn] class YarnAllocator( import YarnAllocator._ - // These two complementary data structures are locked on allocatedHostToContainersMap. // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] @@ -355,20 +354,18 @@ private[yarn] class YarnAllocator( } } - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).get - val containerSet = allocatedHostToContainersMap.get(host).get + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).get + val containerSet = allocatedHostToContainersMap.get(host).get - containerSet.remove(containerId) - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap.remove(containerId) + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) } + + allocatedContainerToHostMap.remove(containerId) } } }