From a0cb111b22cb093e86b0daeecb3dcc41d095df40 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Sun, 5 Jul 2015 20:50:02 -0700 Subject: [PATCH 001/149] [SPARK-8549] [SPARKR] Fix the line length of SparkR [[SPARK-8549] Fix the line length of SparkR - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8549) Author: Yu ISHIKAWA Closes #7204 from yu-iskw/SPARK-8549 and squashes the following commits: 6fb131a [Yu ISHIKAWA] Fix the typo 1737598 [Yu ISHIKAWA] [SPARK-8549][SparkR] Fix the line length of SparkR --- R/pkg/R/generics.R | 3 ++- R/pkg/R/pairRDD.R | 12 ++++++------ R/pkg/R/sparkR.R | 9 ++++++--- R/pkg/R/utils.R | 31 +++++++++++++++++------------- R/pkg/inst/tests/test_includeJAR.R | 4 ++-- R/pkg/inst/tests/test_rdd.R | 12 ++++++++---- R/pkg/inst/tests/test_sparkSQL.R | 11 +++++++++-- 7 files changed, 51 insertions(+), 31 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 79055b7f18558..fad9d71158c51 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -20,7 +20,8 @@ # @rdname aggregateRDD # @seealso reduce # @export -setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) +setGeneric("aggregateRDD", + function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) # @rdname cache-methods # @export diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 7f902ba8e683e..0f1179e0aa51a 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -560,8 +560,8 @@ setMethod("join", # Left outer join two RDDs # # @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -597,8 +597,8 @@ setMethod("leftOuterJoin", # Right outer join two RDDs # # @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -634,8 +634,8 @@ setMethod("rightOuterJoin", # Full outer join two RDDs # # @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 86233e01db365..048eb8ed541e4 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -105,7 +105,8 @@ sparkR.init <- function( sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { - cat("Re-using existing Spark Context. 
Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + cat(paste("Re-using existing Spark Context.", + "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } @@ -180,14 +181,16 @@ sparkR.init <- function( sparkExecutorEnvMap <- new.env() if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { - sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- + paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } for (varname in names(sparkExecutorEnv)) { sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] } nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + localJarPaths <- sapply(nonEmptyJars, + function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs # Seconds resolution is good enough for this purpose, so use ints diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 13cec0f712fb4..ea629a64f7158 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -334,18 +334,21 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY", "MEMORY_ONLY_SER_2", "OFF_HEAP")) { match.arg(newLevel) + storageLevelClass <- "org.apache.spark.storage.StorageLevel" storageLevel <- switch(newLevel, - "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), - "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), - "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), - "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), - "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), - "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), - "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), - "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), - "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), - "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), - "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) + "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) } # Utility function for functions where an argument needs to be integer but we want to allow @@ -545,9 +548,11 @@ mergePartitions <- function(rdd, zip) { lengthOfKeys <- part[[len - lengthOfValues]] stopifnot(len == lengthOfKeys + lengthOfValues) - # For zip 
operation, check if corresponding partitions of both RDDs have the same number of elements. + # For zip operation, check if corresponding partitions + # of both RDDs have the same number of elements. if (zip && lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + stop(paste("Can only zip RDDs with same number of elements", + "in each pair of corresponding partitions.")) } if (lengthOfKeys > 1) { diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R index 844d86f3cc97f..cc1faeabffe30 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -18,8 +18,8 @@ context("include an external JAR in SparkContext") runScript <- function() { sparkHome <- Sys.getenv("SPARK_HOME") - jarPath <- paste("--jars", - shQuote(file.path(sparkHome, "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar"))) + sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" + jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") submitPath <- file.path(sparkHome, "bin/spark-submit") res <- system2(command = submitPath, diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index fc3c01d837de4..b79692873cec3 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -669,13 +669,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), + list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), + list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -683,13 +685,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), + list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), + list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) test_that("sortByKey() on pairwise RDDs", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 0e4235ea8b4b3..b0ea38854304e 100644 --- 
a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -391,7 +391,7 @@ test_that("collect() and take() on a DataFrame return the same number of rows an
expect_equal(ncol(collect(df)), ncol(take(df, 10)))
})
-test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", {
+test_that("multiple pipeline transformations result in an RDD with the correct values", {
df <- jsonFile(sqlContext, jsonPath)
first <- lapply(df, function(row) {
row$age <- row$age + 5
@@ -756,7 +756,14 @@ test_that("toJSON() returns an RDD of the correct values", {
test_that("showDF()", {
df <- jsonFile(sqlContext, jsonPath)
s <- capture.output(showDF(df))
- expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n")
+ expected <- paste("+----+-------+\n",
+ "| age| name|\n",
+ "+----+-------+\n",
+ "|null|Michael|\n",
+ "| 30| Andy|\n",
+ "| 19| Justin|\n",
+ "+----+-------+\n", sep="")
+ expect_output(s , expected)
})
test_that("isLocal()", {

From 6d0411b4f3a202cfb53f638ee5fd49072b42d3a6 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Sun, 5 Jul 2015 21:50:52 -0700
Subject: [PATCH 002/149] [SQL][Minor] Update the DataFrame API for encode/decode

This is a the follow up of #6843.

Author: Cheng Hao

Closes #7230 from chenghao-intel/str_funcs2_followup and squashes the following commits:

52cc553 [Cheng Hao] update the code as comment
---
.../expressions/stringOperations.scala | 21 ++++++++++---------
.../org/apache/spark/sql/functions.scala | 14 +++++++------
.../spark/sql/DataFrameFunctionsSuite.scala | 8 +++++--
3 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 6de40629ff27e..1a14a7a449342 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -392,12 +392,13 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput
/**
* Decodes the first argument into a String using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
- * If either argument is null, the result will also be null. (As of Hive 0.12.0.).
+ * If either argument is null, the result will also be null.
*/
-case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes {
- override def children: Seq[Expression] = bin :: charset :: Nil
- override def foldable: Boolean = bin.foldable && charset.foldable
- override def nullable: Boolean = bin.nullable || charset.nullable
+case class Decode(bin: Expression, charset: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = bin
+ override def right: Expression = charset
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType)
@@ -420,13 +421,13 @@ case class Decode(bin: Expression, charset: Expression) extends Expression with
/**
* Encodes the first argument into a BINARY using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
- * If either argument is null, the result will also be null. (As of Hive 0.12.0.)
+ * If either argument is null, the result will also be null.
*/
case class Encode(value: Expression, charset: Expression)
- extends Expression with ExpectsInputTypes {
- override def children: Seq[Expression] = value :: charset :: Nil
- override def foldable: Boolean = value.foldable && charset.foldable
- override def nullable: Boolean = value.nullable || charset.nullable
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = value
+ override def right: Expression = charset
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index abcfc0b65020c..f80291776f335 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1666,18 +1666,19 @@ object functions {
* @group string_funcs
* @since 1.5.0
*/
- def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr)
+ def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr)
/**
* Computes the first argument into a binary from a string using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
+ * NOTE: charset represents the string value of the character set, not the column name.
*
* @group string_funcs
* @since 1.5.0
*/
- def encode(columnName: String, charsetColumnName: String): Column =
- encode(Column(columnName), Column(charsetColumnName))
+ def encode(columnName: String, charset: String): Column =
+ encode(Column(columnName), charset)
/**
* Computes the first argument into a string from a binary using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* @group string_funcs
* @since 1.5.0
*/
- def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr)
+ def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr)
/**
* Computes the first argument into a string from a binary using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
+ * NOTE: charset represents the string value of the character set, not the column name.
*
* @group string_funcs
* @since 1.5.0
*/
- def decode(columnName: String, charsetColumnName: String): Column =
- decode(Column(columnName), Column(charsetColumnName))
+ def decode(columnName: String, charset: String): Column =
+ decode(Column(columnName), charset)
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index bc455a922d154..afba28515e032 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -261,11 +261,15 @@ class DataFrameFunctionsSuite extends QueryTest {
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
checkAnswer(
- df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b")),
+ df.select(
+ encode($"a", "utf-8"),
+ encode("a", "utf-8"),
+ decode($"c", "utf-8"),
+ decode("c", "utf-8")),
Row(bytes, bytes, "大千世界", "大千世界"))
checkAnswer(
- df.selectExpr("encode(a, b)", "decode(c, b)"),
+ df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
Row(bytes, "大千世界"))
// scalastyle:on
}

From 86768b7b3b0c2964e744bc491bc20a1d3140ce93 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 5 Jul 2015 23:54:25 -0700
Subject: [PATCH 003/149] [SPARK-8831][SQL] Support AbstractDataType in TypeCollection.

Otherwise it is impossible to declare an expression supporting DecimalType.

Author: Reynold Xin

Closes #7232 from rxin/typecollection-adt and squashes the following commits:

934d3d1 [Reynold Xin] [SPARK-8831][SQL] Support AbstractDataType in TypeCollection.
---
.../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 2 --
.../org/apache/spark/sql/types/AbstractDataType.scala | 10 ++++++----
.../sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 6 ++++++
3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 84acc0e7e90ec..5367b7f3308ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -708,8 +708,6 @@ object HiveTypeCoercion {
case (NullType, target) => Cast(e, target.defaultConcreteType)
// Implicit cast among numeric types
- // If input is decimal, and we expect a decimal type, just use the input.
- case (_: DecimalType, DecimalType) => e
// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to unlimited precision decimal.
case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index ffefb0e7837e9..fb1b47e946214 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -53,10 +53,12 @@ private[sql] abstract class AbstractDataType {
*
* This means that we prefer StringType over BinaryType if it is possible to cast to StringType.
*/
-private[sql] class TypeCollection(private val types: Seq[DataType]) extends AbstractDataType {
+private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
+ extends AbstractDataType {
+
require(types.nonEmpty, s"TypeCollection ($types) cannot be empty")
- private[sql] override def defaultConcreteType: DataType = types.head
+ private[sql] override def defaultConcreteType: DataType = types.head.defaultConcreteType
private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
@@ -68,9 +70,9 @@ private[sql] class TypeCollection(private val types: Seq[DataType]) extends Abst
private[sql] object TypeCollection {
- def apply(types: DataType*): TypeCollection = new TypeCollection(types)
+ def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
- def unapply(typ: AbstractDataType): Option[Seq[DataType]] = typ match {
+ def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
case typ: TypeCollection => Some(typ.types)
case _ => None
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 67d05ab536b7f..b56426617789e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -71,6 +71,12 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
+
+ shouldCast(
+ DecimalType.Unlimited, TypeCollection(IntegerType, DecimalType), DecimalType.Unlimited)
+ shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
+ shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
+ shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2))
}
test("ineligible implicit type cast") {

From 39e4e7e4d89077a637c4cad3a986e0e3447d1ae7 Mon Sep 17 00:00:00 2001
From: Steve Lindemann
Date: Mon, 6 Jul 2015 10:17:05 -0700
Subject: [PATCH 004/149] [SPARK-8841] [SQL] Fix partition pruning percentage log message

When pruning partitions for a query plan, a message is logged indicating what how many partitions were selected based on predicate criteria, and what percent were pruned. The current release erroneously uses `1 - total/selected` to compute this quantity, leading to nonsense messages like "pruned -1000% partitions". The fix is simple and obvious.
Author: Steve Lindemann

Closes #7227 from srlindemann/master and squashes the following commits:

c788061 [Steve Lindemann] fix percentPruned log message
---
.../scala/org/apache/spark/sql/sources/DataSourceStrategy.scala | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index ce16e050c56ed..66f7ba90140b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -65,7 +65,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
logInfo {
val total = t.partitionSpec.partitions.length
val selected = selectedPartitions.length
- val percentPruned = (1 - total.toDouble / selected.toDouble) * 100
+ val percentPruned = (1 - selected.toDouble / total.toDouble) * 100
s"Selected $selected partitions out of $total, pruned $percentPruned% partitions."
}

From 293225e0cd9318ad368dde30ac6a17725d33ebb6 Mon Sep 17 00:00:00 2001
From: "Daniel Emaasit (PhD Student)"
Date: Mon, 6 Jul 2015 10:36:02 -0700
Subject: [PATCH 005/149] [SPARK-8124] [SPARKR] Created more examples on SparkR DataFrames

Here are more examples on SparkR DataFrames including creating a Spark Contect and a SQL context, loading data and simple data manipulation.

Author: Daniel Emaasit (PhD Student)

Closes #6668 from Emaasit/dan-dev and squashes the following commits:

3a97867 [Daniel Emaasit (PhD Student)] Used fewer rows for createDataFrame
f7227f9 [Daniel Emaasit (PhD Student)] Using command line arguments
a550f70 [Daniel Emaasit (PhD Student)] Used base R functions
33f9882 [Daniel Emaasit (PhD Student)] Renamed file
b6603e3 [Daniel Emaasit (PhD Student)] changed "Describe" function to "describe"
90565dd [Daniel Emaasit (PhD Student)] Deleted the getting-started file
b95a103 [Daniel Emaasit (PhD Student)] Deleted this file
cc55cd8 [Daniel Emaasit (PhD Student)] combined all the code into one .R file
c6933af [Daniel Emaasit (PhD Student)] changed variable name to SQLContext
8e0fe14 [Daniel Emaasit (PhD Student)] provided two options for creating DataFrames
2653573 [Daniel Emaasit (PhD Student)] Updates to a comment and variable name
275b787 [Daniel Emaasit (PhD Student)] Added the Apache License at the top of the file
2e8f724 [Daniel Emaasit (PhD Student)] Added the Apache License at the top of the file
486f44e [Daniel Emaasit (PhD Student)] Added the Apache License at the file
d705112 [Daniel Emaasit (PhD Student)] Created more examples on SparkR DataFrames
---
examples/src/main/r/data-manipulation.R | 107 ++++++++++++++++++++++++
1 file changed, 107 insertions(+)
create mode 100644 examples/src/main/r/data-manipulation.R

diff --git a/examples/src/main/r/data-manipulation.R b/examples/src/main/r/data-manipulation.R
new file mode 100644
index 0000000000000..aa2336e300a91
--- /dev/null
+++ b/examples/src/main/r/data-manipulation.R
@@ -0,0 +1,107 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# For this example, we shall use the "flights" dataset +# The dataset consists of every flight departing Houston in 2011. +# The data set is made up of 227,496 rows x 14 columns. + +# To run this example use +# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3 +# examples/src/main/r/data-manipulation.R + +# Load SparkR library into your R session +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: data-manipulation.R % + summarize(avg(flightsDF$dep_delay), avg(flightsDF$arr_delay)) -> dailyDelayDF + + # Print the computed data frame + head(dailyDelayDF) +} + +# Stop the SparkContext now +sparkR.stop() From 0e194645f42be0d6ac9b5a712f8fc1798418736d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 Jul 2015 13:26:46 -0700 Subject: [PATCH 006/149] [SPARK-8837][SPARK-7114][SQL] support using keyword in column name Author: Wenchen Fan Closes #7237 from cloud-fan/parser and squashes the following commits: e7b49bb [Wenchen Fan] support using keyword in column name --- .../apache/spark/sql/catalyst/SqlParser.scala | 28 ++++++++++++------- .../org/apache/spark/sql/SQLQuerySuite.scala | 9 ++++++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 8d02fbf4f92c4..e8e9b9802e94b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -287,15 +287,18 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { throw new AnalysisException(s"invalid function approximate($floatLit) $udfName") } } - | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ - (ELSE ~> expression).? <~ END ^^ { - case casePart ~ altPart ~ elsePart => - val branches = altPart.flatMap { case whenExpr ~ thenExpr => - Seq(whenExpr, thenExpr) - } ++ elsePart - casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches)) - } - ) + | CASE ~> whenThenElse ^^ CaseWhen + | CASE ~> expression ~ whenThenElse ^^ + { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) } + ) + + protected lazy val whenThenElse: Parser[List[Expression]] = + rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? 
<~ END ^^ { + case altPart ~ elsePart => + altPart.flatMap { case whenExpr ~ thenExpr => + Seq(whenExpr, thenExpr) + } ++ elsePart + } protected lazy val cast: Parser[Expression] = CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { @@ -354,6 +357,11 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val signedPrimary: Parser[Expression] = sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e} + protected lazy val attributeName: Parser[String] = acceptMatch("attribute name", { + case lexical.Identifier(str) => str + case lexical.Keyword(str) if !lexical.delimiters.contains(str) => str + }) + protected lazy val primary: PackratParser[Expression] = ( literal | expression ~ ("[" ~> expression <~ "]") ^^ @@ -364,9 +372,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | "(" ~> expression <~ ")" | function | dotExpressionHeader - | ident ^^ {case i => UnresolvedAttribute.quoted(i)} | signedPrimary | "~" ~> expression ^^ BitwiseNot + | attributeName ^^ UnresolvedAttribute.quoted ) protected lazy val dotExpressionHeader: Parser[Expression] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cc6af1ccc1cce..12ad019e8b473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1458,4 +1458,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) } } + + test("SPARK-8837: use keyword in column name") { + withTempTable("t") { + val df = Seq(1 -> "a").toDF("count", "sort") + checkAnswer(df.filter("count > 0"), Row(1, "a")) + df.registerTempTable("t") + checkAnswer(sql("select count, sort from t"), Row(1, "a")) + } + } } From 57c72fcce75907c08a1ae53a0d85447176fc3c69 Mon Sep 17 00:00:00 2001 From: Dirceu Semighini Filho Date: Mon, 6 Jul 2015 13:28:07 -0700 Subject: [PATCH 007/149] Small update in the readme file Just change the attribute from -PsparkR to -Psparkr Author: Dirceu Semighini Filho Closes #7242 from dirceusemighini/patch-1 and squashes the following commits: fad5991 [Dirceu Semighini Filho] Small update in the readme file --- R/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/README.md b/R/README.md index d7d65b4f0eca5..005f56da1670c 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R #### Build Spark -Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. 
For example to use the default Hadoop versions you can run ``` build/mvn -DskipTests -Psparkr package ``` From 37e4d92142a6309e2df7d36883e0c7892c3d792d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 6 Jul 2015 13:31:31 -0700 Subject: [PATCH 008/149] [SPARK-8784] [SQL] Add Python API for hex and unhex Add Python API for hex/unhex, also cleanup Hex/Unhex Author: Davies Liu Closes #7223 from davies/hex and squashes the following commits: 6f1249d [Davies Liu] no explicit rule to cast string into binary 711a6ed [Davies Liu] fix test f9fe5a3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into hex f032fbb [Davies Liu] Merge branch 'hex' of github.com:davies/spark into hex 49e325f [Davies Liu] Merge branch 'master' of github.com:apache/spark into hex b31fc9a [Davies Liu] Update math.scala 25156b7 [Davies Liu] address comments and fix test c3af78c [Davies Liu] address commments 1a24082 [Davies Liu] Add Python API for hex and unhex --- python/pyspark/sql/functions.py | 28 +++++++ .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 83 ++++++++++--------- .../expressions/MathFunctionsSuite.scala | 25 ++++-- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 93 insertions(+), 47 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 49dd0332afe74..dca39fa833435 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -395,6 +395,34 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def hex(col): + """Computes hex value of the given column, which could be StringType, + BinaryType, IntegerType or LongType. + + >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + [Row(hex(a)=u'414243', hex(b)=u'3')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def unhex(col): + """Inverse of hex. Interprets each pair of characters as a hexadecimal number + and converts to the byte representation of number. 
+ + >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + [Row(unhex(a)=bytearray(b'ABC'))] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.unhex(_to_java_column(col)) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def sha1(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 92a50e7092317..fef276353022c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -168,7 +168,7 @@ object FunctionRegistry { expression[Substring]("substring"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), - expression[UnHex]("unhex"), + expression[Unhex]("unhex"), expression[Upper]("upper"), // datetime functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 45b7e4d3405c8..92500453980f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -298,6 +298,21 @@ case class Bin(child: Expression) } } +object Hex { + val hexDigits = Array[Char]( + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + ).map(_.toByte) + + // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15 + val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } +} + /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. * Otherwise if the number is a STRING, it converts each character into its hex representation @@ -307,7 +322,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, StringType, BinaryType)) + Seq(TypeCollection(LongType, BinaryType, StringType)) override def dataType: DataType = StringType @@ -319,30 +334,18 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes child.dataType match { case LongType => hex(num.asInstanceOf[Long]) case BinaryType => hex(num.asInstanceOf[Array[Byte]]) - case StringType => hex(num.asInstanceOf[UTF8String]) + case StringType => hex(num.asInstanceOf[UTF8String].getBytes) } } } - /** - * Converts every character in s to two hex digits. 
- */ - private def hex(str: UTF8String): UTF8String = { - hex(str.getBytes) - } - - private def hex(bytes: Array[Byte]): UTF8String = { - doHex(bytes, bytes.length) - } - - private def doHex(bytes: Array[Byte], length: Int): UTF8String = { + private[this] def hex(bytes: Array[Byte]): UTF8String = { + val length = bytes.length val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = Character.toUpperCase(Character.forDigit( - (bytes(i) & 0xF0) >>> 4, 16)).toByte - value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( - bytes(i) & 0x0F, 16)).toByte + value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) i += 1 } UTF8String.fromBytes(value) @@ -355,24 +358,23 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes var len = 0 do { len += 1 - value(value.length - len) = - Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte + value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) } } - /** * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ -case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes { // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def nullable: Boolean = true override def dataType: DataType = BinaryType override def eval(input: InternalRow): Any = { @@ -384,26 +386,31 @@ case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTyp } } - private val unhexDigits = { - val array = Array.fill[Byte](128)(-1) - (0 to 9).foreach(i => array('0' + i) = i.toByte) - (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) - (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) - array - } - - private def unhex(inputBytes: Array[Byte]): Array[Byte] = { - var bytes = inputBytes + private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { + val out = new Array[Byte]((bytes.length + 1) >> 1) + var i = 0 if ((bytes.length & 0x01) != 0) { - bytes = '0'.toByte +: bytes + // padding with '0' + if (bytes(0) < 0) { + return null + } + val v = Hex.unhexDigits(bytes(0)) + if (v == -1) { + return null + } + out(0) = v + i += 1 } - val out = new Array[Byte](bytes.length >> 1) // two characters form the hex value. 
- var i = 0 while (i < bytes.length) { - val first = unhexDigits(bytes(i)) - val second = unhexDigits(bytes(i + 1)) - if (first == -1 || second == -1) { return null} + if (bytes(i) < 0 || bytes(i + 1) < 0) { + return null + } + val first = Hex.unhexDigits(bytes(i)) + val second = Hex.unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { + return null + } out(i / 2) = (((first << 4) | second) & 0xFF).toByte i += 2 } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 03d8400cf356b..7ca9e30b2bcd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -21,8 +21,7 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DataType, LongType} -import org.apache.spark.sql.types.{IntegerType, DoubleType} +import org.apache.spark.sql.types._ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -271,20 +270,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("hex") { + checkEvaluation(Hex(Literal.create(null, LongType)), null) + checkEvaluation(Hex(Literal(28L)), "1C") + checkEvaluation(Hex(Literal(-28L)), "FFFFFFFFFFFFFFE4") checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") - checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") + checkEvaluation(Hex(Literal.create(null, BinaryType)), null) checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") // scalastyle:off // Turn off scala style for non-ascii chars - checkEvaluation(Hex(Literal("三重的")), "E4B889E9878DE79A84") + checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") // scalastyle:on } test("unhex") { - checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) - checkEvaluation(UnHex(Literal("")), new Array[Byte](0)) - checkEvaluation(UnHex(Literal("0")), Array[Byte](0)) + checkEvaluation(Unhex(Literal.create(null, StringType)), null) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) + checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) + checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) + checkEvaluation(Unhex(Literal("GG")), null) + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) + checkEvaluation(Unhex(Literal("三重的")), null) + + // scalastyle:on } test("hypot") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f80291776f335..4da9ffc495e17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1095,7 +1095,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = UnHex(column.expr) + def unhex(column: Column): Column = Unhex(column.expr) /** * Inverse of hex. 
Interprets each pair of characters as a hexadecimal number From 2471c0bf7f463bb144b44a2e51c0f363e71e099d Mon Sep 17 00:00:00 2001 From: kai Date: Mon, 6 Jul 2015 14:33:30 -0700 Subject: [PATCH 009/149] [SPARK-4485] [SQL] 1) Add broadcast hash outer join, (2) Fix SparkPlanTest This pull request (1) extracts common functions used by hash outer joins and put it in interface HashOuterJoin (2) adds ShuffledHashOuterJoin and BroadcastHashOuterJoin (3) adds test cases for shuffled and broadcast hash outer join (3) makes SparkPlanTest to support binary or more complex operators, and fixes bugs in plan composition in SparkPlanTest Author: kai Closes #7162 from kai-zeng/outer and squashes the following commits: 3742359 [kai] Fix not-serializable exception for code-generated keys in broadcasted relations 14e4bf8 [kai] Use CanBroadcast in broadcast outer join planning dc5127e [kai] code style fixes b5a4efa [kai] (1) Add broadcast hash outer join, (2) Fix SparkPlanTest --- .../spark/sql/execution/SparkStrategies.scala | 12 +- .../joins/BroadcastHashOuterJoin.scala | 121 ++++++++++++++++++ .../sql/execution/joins/HashOuterJoin.scala | 95 ++++---------- .../joins/ShuffledHashOuterJoin.scala | 85 ++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 40 +++++- .../spark/sql/execution/SparkPlanTest.scala | 99 +++++++++++--- .../sql/execution/joins/OuterJoinSuite.scala | 88 +++++++++++++ 7 files changed, 441 insertions(+), 99 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5daf86d817586..32044989044a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -117,8 +117,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case ExtractEquiJoinKeys( + LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys( + RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - joins.HashOuterJoin( + joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala new file mode 100644 index 0000000000000..5da04c78744d9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.ThreadUtils + +import scala.collection.JavaConversions._ +import scala.concurrent._ +import scala.concurrent.duration._ + +/** + * :: DeveloperApi :: + * Performs a outer hash join for two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + val timeout = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + override def requiredChildDistribution: Seq[Distribution] = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + private[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + private[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + @transient + private val broadcastFuture = future { + // Note that we use .execute().collect() because we don't want to convert data to Scala types + val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() + // buildHashTable uses code-generated rows as keys, which are not serializable + val hashed = + buildHashTable(input.iterator, new InterpretedProjection(buildKeys, buildPlan.output)) + sparkContext.broadcast(hashed) + }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) + + override def doExecute(): RDD[InternalRow] = { + val broadcastRelation = Await.result(broadcastFuture, timeout) + + streamedPlan.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow() + val hashTable = broadcastRelation.value + val keyGenerator = newProjection(streamedKeys, streamedPlan.output) + + joinType match { + case LeftOuter => 
+ streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + } + } +} + +object BroadcastHashOuterJoin { + + private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index e41538ec1fc1a..886b5fa0c5103 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -19,32 +19,25 @@ package org.apache.spark.sql.execution.joins import java.util.{HashMap => JavaHashMap} -import org.apache.spark.rdd.RDD - -import scala.collection.JavaConversions._ - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. 
- */ @DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { +trait HashOuterJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val joinType: JoinType + val condition: Option[Expression] + val left: SparkPlan + val right: SparkPlan + +override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) @@ -52,9 +45,6 @@ case class HashOuterJoin( throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } - override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = { joinType match { case LeftOuter => @@ -68,8 +58,8 @@ case class HashOuterJoin( } } - @transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow] + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) + @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @@ -80,7 +70,7 @@ case class HashOuterJoin( // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. - private[this] def leftOuterIterator( + protected[this] def leftOuterIterator( key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { @@ -89,7 +79,7 @@ case class HashOuterJoin( val temp = rightIter.collect { case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withRight(rightNullRow).copy :: Nil } else { temp @@ -101,18 +91,17 @@ case class HashOuterJoin( ret.iterator } - private[this] def rightOuterIterator( + protected[this] def rightOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = leftIter.collect { case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy + joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withLeft(leftNullRow).copy :: Nil } else { temp @@ -124,10 +113,9 @@ case class HashOuterJoin( ret.iterator } - private[this] def fullOuterIterator( + protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], joinedRow: JoinedRow): Iterator[InternalRow] = { - if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. 
@@ -171,7 +159,7 @@ case class HashOuterJoin( } } - private[this] def buildHashTable( + protected[this] def buildHashTable( iter: Iterator[InternalRow], keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() @@ -190,43 +178,4 @@ case class HashOuterJoin( hashTable } - - protected override def doExecute(): RDD[InternalRow] = { - val joinedRow = new JoinedRow() - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - - joinType match { - case LeftOuter => - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val keyGenerator = newProjection(leftKeys, left.output) - leftIter.flatMap( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) - }) - - case RightOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val keyGenerator = newProjection(rightKeys, right.output) - rightIter.flatMap ( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) - }) - - case FullOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) - } - - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala new file mode 100644 index 0000000000000..cfc9c14aaa363 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +import scala.collection.JavaConversions._ + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class ShuffledHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val joinedRow = new JoinedRow() + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + joinType match { + case LeftOuter => + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val keyGenerator = newProjection(leftKeys, left.output) + leftIter.flatMap( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val keyGenerator = newProjection(rightKeys, right.output) + rightIter.flatMap ( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case FullOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST), + joinedRow) + } + + case x => + throw new IllegalArgumentException( + s"ShuffledHashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 20390a5544304..8953889d1fae9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -45,9 +45,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j - case j: HashOuterJoin => j + case j: ShuffledHashOuterJoin => j case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j + case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j @@ -81,12 +82,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), ("SELECT * 
FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[HashOuterJoin]), + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[HashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -133,6 +135,34 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.sql("UNCACHE TABLE testData") } + test("broadcasted hash outer join operator selection") { + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + } + + ctx.sql("UNCACHE TABLE testData") + } + test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 13f3be8ca28d6..108b1122f7bff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -54,6 +54,37 @@ class SparkPlanTest extends SparkFunSuite { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[Row]): Unit = { + checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 
+ */ + protected def checkAnswer( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + checkAnswer(left :: right :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -72,11 +103,41 @@ class SparkPlanTest extends SparkFunSuite { planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[A]): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) - SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } + checkAnswer(input, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(left, right, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(input, planFunction, expectedRows) } + } /** @@ -92,27 +153,25 @@ object SparkPlanTest { * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ def checkAnswer( - input: DataFrame, - planFunction: SparkPlan => SparkPlan, + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row]): Option[String] = { - val outputPlan = planFunction(input.queryExecution.sparkPlan) + val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = outputPlan transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { - case (a, i) => - (a.name, BoundReference(i, a.dataType, a.nullable)) - }.toMap - - plan.transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } + val resolvedPlan = TestSQLContext.prepareForExecution.execute( + outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + ) def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala new file mode 100644 index 0000000000000..5707d2fb300ae --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} + +class OuterJoinSuite extends SparkPlanTest { + + val left = Seq( + (1, 2.0), + (2, 1.0), + (3, 3.0) + ).toDF("a", "b") + + val right = Seq( + (2, 3.0), + (3, 2.0), + (4, 1.0) + ).toDF("c", "d") + + val leftKeys: List[Expression] = 'a :: Nil + val rightKeys: List[Expression] = 'c :: Nil + val condition = Some(LessThan('b, 'd)) + + test("shuffled hash outer join") { + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), + Seq( + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + } + + test("broadcast hash outer join") { + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), + Seq( + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + } +} From 132e7fca129be8f00ba429a51bcef60abb2eed6d Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 6 Jul 2015 15:54:43 -0700 Subject: [PATCH 010/149] [MINOR] [SQL] remove unused code in Exchange Author: Daoyuan Wang Closes #7234 from adrian-wang/exchangeclean and squashes the following commits: b093ec9 [Daoyuan Wang] remove unused code --- .../org/apache/spark/sql/execution/Exchange.scala | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index edc64a03335d6..e054c1d144e34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -117,20 +117,6 @@ case class Exchange( } } - private val keyOrdering = { - if (newOrdering.nonEmpty) { - val key = newPartitioning.keyExpressions - val boundOrdering = newOrdering.map { o => - val ordinal = key.indexOf(o.child) - if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") - o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) - } - new RowOrdering(boundOrdering) - } else { - null // Ordering will not be used - } - } - @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf private def getSerializer( From 9ff203346ca4decf2999e33bfb8c400ec75313e6 Mon Sep 17 00:00:00 2001 From: Wisely Chen Date: Mon, 6 Jul 2015 16:04:01 -0700 Subject: [PATCH 011/149] [SPARK-8656] [WEBUI] Fix the webUI and JSON API number is not synced 
Spark standalone master web UI show "Alive Workers" total core, total used cores and "Alive workers" total memory, memory used. But the JSON API page "http://MASTERURL:8088/json" shows "ALL workers" core, memory number. This webUI data is not sync with the JSON API. The proper way is to sync the number with webUI and JSON API. Author: Wisely Chen Closes #7038 from thegiive/SPARK-8656 and squashes the following commits: 9e54bf0 [Wisely Chen] Change variable name to camel case 2c8ea89 [Wisely Chen] Change some styling and add local variable 431d2b0 [Wisely Chen] Worker List should contain DEAD node also 8b3b8e8 [Wisely Chen] [SPARK-8656] Fix the webUI and JSON API number is not synced --- .../scala/org/apache/spark/deploy/JsonProtocol.scala | 9 +++++---- .../org/apache/spark/deploy/master/WorkerInfo.scala | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 2954f932b4f41..ccffb36652988 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -76,12 +76,13 @@ private[deploy] object JsonProtocol { } def writeMasterState(obj: MasterStateResponse): JObject = { + val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - ("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ + ("cores" -> aliveWorkers.map(_.cores).sum) ~ + ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ + ("memory" -> aliveWorkers.map(_.memory).sum) ~ + ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 471811037e5e2..f751966605206 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -105,4 +105,6 @@ private[spark] class WorkerInfo( def setState(state: WorkerState.Value): Unit = { this.state = state } + + def isAlive(): Boolean = this.state == WorkerState.ALIVE } From 1165b17d24cdf1dbebb2faca14308dfe5c2a652c Mon Sep 17 00:00:00 2001 From: Ankur Chauhan Date: Mon, 6 Jul 2015 16:04:57 -0700 Subject: [PATCH 012/149] [SPARK-6707] [CORE] [MESOS] Mesos Scheduler should allow the user to specify constraints based on slave attributes Currently, the mesos scheduler only looks at the 'cpu' and 'mem' resources when trying to determine the usablility of a resource offer from a mesos slave node. It may be preferable for the user to be able to ensure that the spark jobs are only started on a certain set of nodes (based on attributes). For example, If the user sets a property, let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. 
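To make the matching behaviour concrete, here is a small hedged sketch of how the helpers introduced by this patch fit together. It mirrors the usage in the new MesosSchedulerUtilsSuite below rather than the backend code itself; the object name and attribute values are invented for illustration, and it assumes it is compiled inside the same `mesos` package because the trait is `private[mesos]`.

```scala
package org.apache.spark.scheduler.cluster.mesos  // required: MesosSchedulerUtils is private[mesos]

import org.apache.mesos.Protos.Value

object ConstraintMatchingSketch {
  // Anonymous instance of the new trait, exactly as MesosSchedulerUtilsSuite does.
  private val utils = new MesosSchedulerUtils {}

  def main(args: Array[String]): Unit = {
    // "tachyon" must equal the text value "true"; the required zones must be a
    // subset of the zones the slave advertises.
    val constraints = utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b")

    val offerAttributes = Map(
      "tachyon" -> Value.Text.newBuilder().setValue("true").build(),
      "zone" -> Value.Set.newBuilder()
        .addItem("us-east-1a").addItem("us-east-1b").addItem("us-east-1c").build())

    // Prints "true": both constraints are satisfied by this offer's attributes.
    println(utils.matchesAttributeRequirements(constraints, offerAttributes))
  }
}
```

Offers that fail this attribute check (or the existing memory/CPU checks) are now declined outright instead of being launched with an empty task list.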
Author: Ankur Chauhan Closes #5563 from ankurcha/mesos_attribs and squashes the following commits: 902535b [Ankur Chauhan] Fix line length d83801c [Ankur Chauhan] Update code as per code review comments 8b73f2d [Ankur Chauhan] Fix imports c3523e7 [Ankur Chauhan] Added docs 1a24d0b [Ankur Chauhan] Expand scope of attributes matching to include all data types 482fd71 [Ankur Chauhan] Update access modifier to private[this] for offer constraints 5ccc32d [Ankur Chauhan] Fix nit pick whitespace 1bce782 [Ankur Chauhan] Fix nit pick whitespace c0cbc75 [Ankur Chauhan] Use offer id value for debug message 7fee0ea [Ankur Chauhan] Add debug statements fc7eb5b [Ankur Chauhan] Fix import codestyle 00be252 [Ankur Chauhan] Style changes as per code review comments 662535f [Ankur Chauhan] Incorporate code review comments + use SparkFunSuite fdc0937 [Ankur Chauhan] Decline offers that did not meet criteria 67b58a0 [Ankur Chauhan] Add documentation for spark.mesos.constraints 63f53f4 [Ankur Chauhan] Update codestyle - uniform style for config values 02031e4 [Ankur Chauhan] Fix scalastyle warnings in tests c09ed84 [Ankur Chauhan] Fixed the access modifier on offerConstraints val to private[mesos] 0c64df6 [Ankur Chauhan] Rename overhead fractions to memory_*, fix spacing 8cc1e8f [Ankur Chauhan] Make exception message more explicit about the source of the error addedba [Ankur Chauhan] Added test case for malformed constraint string ec9d9a6 [Ankur Chauhan] Add tests for parse constraint string 72fe88a [Ankur Chauhan] Fix up tests + remove redundant method override, combine utility class into new mesos scheduler util trait 92b47fd [Ankur Chauhan] Add attributes based constraints support to MesosScheduler --- .../mesos/CoarseMesosSchedulerBackend.scala | 43 +++-- .../scheduler/cluster/mesos/MemoryUtils.scala | 31 ---- .../cluster/mesos/MesosClusterScheduler.scala | 1 + .../cluster/mesos/MesosSchedulerBackend.scala | 62 ++++--- .../cluster/mesos/MesosSchedulerUtils.scala | 153 +++++++++++++++++- .../cluster/mesos/MemoryUtilsSuite.scala | 46 ------ .../mesos/MesosSchedulerBackendSuite.scala | 6 +- .../mesos/MesosSchedulerUtilsSuite.scala | 140 ++++++++++++++++ docs/running-on-mesos.md | 22 +++ 9 files changed, 376 insertions(+), 128 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala delete mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 6b8edca5aa485..b68f8c7685eba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,18 +18,18 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{Collections, List => JList} +import java.util.{List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} import 
org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -66,6 +66,10 @@ private[spark] class CoarseMesosSchedulerBackend( val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + // Offer constraints + private val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -170,13 +174,16 @@ private[spark] class CoarseMesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers) { + val offerAttributes = toAttributeMap(offer.getAttributesList) + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) val slaveId = offer.getSlaveId.toString val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && - mem >= MemoryUtils.calculateTotalMemory(sc) && + val id = offer.getId.getValue + if (meetsConstraints && + totalCoresAcquired < maxCores && + mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { @@ -193,33 +200,25 @@ private[spark] class CoarseMesosSchedulerBackend( .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", - MemoryUtils.calculateTotalMemory(sc))) + .addResources(createResource("mem", calculateTotalMemory(sc))) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder()) + .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder) } - d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters) + // accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.launchTasks(List(offer.getId), List(task.build()), filters) } else { - // Filter it out - d.launchTasks( - Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters) + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } } } } - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala deleted file mode 100644 index 8df4f3b554c41..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ /dev/null @@ -1,31 +0,0 @@ 
-/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster.mesos - -import org.apache.spark.SparkContext - -private[spark] object MemoryUtils { - // These defaults copied from YARN - val OVERHEAD_FRACTION = 0.10 - val OVERHEAD_MINIMUM = 384 - - def calculateTotalMemory(sc: SparkContext): Int = { - sc.conf.getInt("spark.mesos.executor.memoryOverhead", - math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1067a7f1caf4c..d3a20f822176e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.{Scheduler, SchedulerDriver} + import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 49de85ef48ada..d72e2af456e15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -23,14 +23,14 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
Each Spark task is mapped to a @@ -59,6 +59,10 @@ private[spark] class MesosSchedulerBackend( private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) + // Offer constraints + private[this] val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + @volatile var appId: String = _ override def start() { @@ -71,8 +75,8 @@ private[spark] class MesosSchedulerBackend( val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { - throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") - } + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => environment.addVariables( @@ -115,14 +119,14 @@ private[spark] class MesosSchedulerBackend( .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder() - .setValue(mesosExecutorCores).build()) + .setValue(mesosExecutorCores).build()) .build() val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) .setScalar( Value.Scalar.newBuilder() - .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) + .setValue(calculateTotalMemory(sc)).build()) .build() val executorInfo = MesosExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) @@ -191,13 +195,31 @@ private[spark] class MesosSchedulerBackend( val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) + val offerAttributes = toAttributeMap(o.getAttributesList) + + // check if all constraints are satisfield + // 1. Attribute constraints + // 2. Memory requirements + // 3. 
CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + val meetsRequirements = + (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + + // add some debug messaging + val debugstr = if (meetsRequirements) "Accepting" else "Declining" + val id = o.getId.getValue + logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + + meetsRequirements } + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) + val workerOffers = usableOffers.map { o => val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt @@ -223,15 +245,15 @@ private[spark] class MesosSchedulerBackend( val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) acceptedOffers .foreach { offer => - offer.foreach { taskDesc => - val slaveId = taskDesc.executorId - slaveIdsWithExecutors += slaveId - slavesIdsOfAcceptedOffers += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(createMesosTask(taskDesc, slaveId)) - } + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slaveIdsWithExecutors += slaveId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(createMesosTask(taskDesc, slaveId)) } + } // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? 
@@ -251,8 +273,6 @@ private[spark] class MesosSchedulerBackend( d.declineOffer(o.getId) } - // Decline offers we ruled out immediately - unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d11228f3d016a..d8a8c848bb4d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,14 +17,17 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util.List +import java.util.{List => JList} import java.util.concurrent.CountDownLatch import scala.collection.JavaConversions._ +import scala.util.control.NonFatal -import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} -import org.apache.mesos.{MesosSchedulerDriver, Scheduler} -import org.apache.spark.Logging +import com.google.common.base.Splitter +import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler} +import org.apache.mesos.Protos._ +import org.apache.mesos.protobuf.GeneratedMessage +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.util.Utils /** @@ -86,10 +89,150 @@ private[mesos] trait MesosSchedulerUtils extends Logging { /** * Get the amount of resources for the specified type from the resource list */ - protected def getResource(res: List[Resource], name: String): Double = { + protected def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } 0.0 } + + /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */ + protected def getAttribute(attr: Attribute): (String, Set[String]) = { + (attr.getName, attr.getText.getValue.split(',').toSet) + } + + + /** Build a Mesos resource protobuf object */ + protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { + Resource.newBuilder() + .setName(resourceName) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) + .build() + } + + /** + * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * The attribute values are the mesos attribute types and they are + * @param offerAttributes + * @return + */ + protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + offerAttributes.map(attr => { + val attrValue = attr.getType match { + case Value.Type.SCALAR => attr.getScalar + case Value.Type.RANGES => attr.getRanges + case Value.Type.SET => attr.getSet + case Value.Type.TEXT => attr.getText + } + (attr.getName, attrValue) + }).toMap + } + + + /** + * Match the requirements (if any) to the offer attributes. 
+ * if attribute requirements are not specified - return true + * else if attribute is defined and no values are given, simple attribute presence is performed + * else if attribute name and value is specified, subset match is performed on slave attributes + */ + def matchesAttributeRequirements( + slaveOfferConstraints: Map[String, Set[String]], + offerAttributes: Map[String, GeneratedMessage]): Boolean = { + slaveOfferConstraints.forall { + // offer has the required attribute and subsumes the required values for that attribute + case (name, requiredValues) => + offerAttributes.get(name) match { + case None => false + case Some(_) if requiredValues.isEmpty => true // empty value matches presence + case Some(scalarValue: Value.Scalar) => + // check if provided values is less than equal to the offered values + requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue) + case Some(rangeValue: Value.Range) => + val offerRange = rangeValue.getBegin to rangeValue.getEnd + // Check if there is some required value that is between the ranges specified + // Note: We only support the ability to specify discrete values, in the future + // we may expand it to subsume ranges specified with a XX..YY value or something + // similar to that. + requiredValues.map(_.toLong).exists(offerRange.contains(_)) + case Some(offeredValue: Value.Set) => + // check if the specified required values is a subset of offered set + requiredValues.subsetOf(offeredValue.getItemList.toSet) + case Some(textValue: Value.Text) => + // check if the specified value is equal, if multiple values are specified + // we succeed if any of them match. + requiredValues.contains(textValue.getValue) + } + } + } + + /** + * Parses the attributes constraints provided to spark and build a matching data struct: + * Map[, Set[values-to-match]] + * The constraints are specified as ';' separated key-value pairs where keys and values + * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for + * multiple values (comma separated). For example: + * {{{ + * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * // would result in + * + * Map( + * "tachyon" -> Set("true"), + * "zone": -> Set("us-east-1a", "us-east-1b") + * ) + * }}} + * + * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/ + * https://github.com/apache/mesos/blob/master/src/common/values.cpp + * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp + * + * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * by ':') + * @return Map of constraints to match resources offers. 
+ */ + def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = { + /* + Based on mesos docs: + attributes : attribute ( ";" attribute )* + attribute : labelString ":" ( labelString | "," )+ + labelString : [a-zA-Z0-9_/.-] + */ + val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':') + // kv splitter + if (constraintsVal.isEmpty) { + Map() + } else { + try { + Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map { + case (k, v) => + if (v == null || v.isEmpty) { + (k, Set[String]()) + } else { + (k, v.split(',').toSet) + } + } + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) + } + } + } + + // These defaults copied from YARN + private val MEMORY_OVERHEAD_FRACTION = 0.10 + private val MEMORY_OVERHEAD_MINIMUM = 384 + + /** + * Return the amount of memory to allocate to each executor, taking into account + * container overheads. + * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value + * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * (whichever is larger) + */ + def calculateTotalMemory(sc: SparkContext): Int = { + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + + sc.executorMemory + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala deleted file mode 100644 index e72285d03d3ee..0000000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler.cluster.mesos - -import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar - -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} - -class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar { - test("MesosMemoryUtils should always override memoryOverhead when it's set") { - val sparkConf = new SparkConf - - val sc = mock[SparkContext] - when(sc.conf).thenReturn(sparkConf) - - // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 - when(sc.executorMemory).thenReturn(512) - assert(MemoryUtils.calculateTotalMemory(sc) === 896) - - // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 - when(sc.executorMemory).thenReturn(4096) - assert(MemoryUtils.calculateTotalMemory(sc) === 4505) - - // set memoryOverhead - sparkConf.set("spark.mesos.executor.memoryOverhead", "100") - assert(MemoryUtils.calculateTotalMemory(sc) === 4196) - sparkConf.set("spark.mesos.executor.memoryOverhead", "400") - assert(MemoryUtils.calculateTotalMemory(sc) === 4496) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 68df46a41ddc8..d01837fe78957 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -149,7 +149,9 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(sc.conf).thenReturn(new SparkConf) when(sc.listenerBus).thenReturn(listenerBus) - val minMem = MemoryUtils.calculateTotalMemory(sc).toInt + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val minMem = backend.calculateTotalMemory(sc) val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] @@ -157,8 +159,6 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi mesosOffers.add(createOffer(2, minMem - 1, minCpu)) mesosOffers.add(createOffer(3, minMem, minCpu)) - val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](2) expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(0).getSlaveId.getValue, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala new file mode 100644 index 0000000000000..b354914b6ffd0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.mesos.Protos.Value +import org.mockito.Mockito._ +import org.scalatest._ +import org.scalatest.mock.MockitoSugar +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { + + // scalastyle:off structural.type + // this is the documented way of generating fixtures in scalatest + def fixture: Object {val sc: SparkContext; val sparkConf: SparkConf} = new { + val sparkConf = new SparkConf + val sc = mock[SparkContext] + when(sc.conf).thenReturn(sparkConf) + } + val utils = new MesosSchedulerUtils { } + // scalastyle:on structural.type + + test("use at-least minimum overhead") { + val f = fixture + when(f.sc.executorMemory).thenReturn(512) + utils.calculateTotalMemory(f.sc) shouldBe 896 + } + + test("use overhead if it is greater than minimum value") { + val f = fixture + when(f.sc.executorMemory).thenReturn(4096) + utils.calculateTotalMemory(f.sc) shouldBe 4505 + } + + test("use spark.mesos.executor.memoryOverhead (if set)") { + val f = fixture + when(f.sc.executorMemory).thenReturn(1024) + f.sparkConf.set("spark.mesos.executor.memoryOverhead", "512") + utils.calculateTotalMemory(f.sc) shouldBe 1536 + } + + test("parse a non-empty constraint string correctly") { + val expectedMap = Map( + "tachyon" -> Set("true"), + "zone" -> Set("us-east-1a", "us-east-1b") + ) + utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") should be (expectedMap) + } + + test("parse an empty constraint string correctly") { + utils.parseConstraintString("") shouldBe Map() + } + + test("throw an exception when the input is malformed") { + an[IllegalArgumentException] should be thrownBy + utils.parseConstraintString("tachyon;zone:us-east") + } + + test("empty values for attributes' constraints matches all values") { + val constraintsStr = "tachyon:" + val parsedConstraints = utils.parseConstraintString(constraintsStr) + + parsedConstraints shouldBe Map("tachyon" -> Set()) + + val zoneSet = Value.Set.newBuilder().addItem("us-east-1a").addItem("us-east-1b").build() + val noTachyonOffer = Map("zone" -> zoneSet) + val tachyonTrueOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + val tachyonFalseOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("false").build()) + + utils.matchesAttributeRequirements(parsedConstraints, noTachyonOffer) shouldBe false + utils.matchesAttributeRequirements(parsedConstraints, tachyonTrueOffer) shouldBe true + utils.matchesAttributeRequirements(parsedConstraints, tachyonFalseOffer) shouldBe true + } + + test("subset match is performed for set attributes") { + val supersetConstraint = Map( + "tachyon" -> Value.Text.newBuilder().setValue("true").build(), + "zone" -> Value.Set.newBuilder() + .addItem("us-east-1a") + .addItem("us-east-1b") + .addItem("us-east-1c") + .build()) + + val zoneConstraintStr = "tachyon:;zone:us-east-1a,us-east-1c" + val parsedConstraints = utils.parseConstraintString(zoneConstraintStr) + + utils.matchesAttributeRequirements(parsedConstraints, supersetConstraint) shouldBe true + } + + test("less than equal match is performed on scalar attributes") { + val offerAttribs = Map("gpus" -> Value.Scalar.newBuilder().setValue(3).build()) + + val ltConstraint = utils.parseConstraintString("gpus:2") + val eqConstraint = utils.parseConstraintString("gpus:3") + val gtConstraint = utils.parseConstraintString("gpus:4") + + 
utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + } + + test("contains match is performed for range attributes") { + val offerAttribs = Map("ports" -> Value.Range.newBuilder().setBegin(7000).setEnd(8000).build()) + val ltConstraint = utils.parseConstraintString("ports:6000") + val eqConstraint = utils.parseConstraintString("ports:7500") + val gtConstraint = utils.parseConstraintString("ports:8002") + val multiConstraint = utils.parseConstraintString("ports:5000,7500,8300") + + utils.matchesAttributeRequirements(ltConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(eqConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(gtConstraint, offerAttribs) shouldBe false + utils.matchesAttributeRequirements(multiConstraint, offerAttribs) shouldBe true + } + + test("equality match is performed for text attributes") { + val offerAttribs = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + + val trueConstraint = utils.parseConstraintString("tachyon:true") + val falseConstraint = utils.parseConstraintString("tachyon:false") + + utils.matchesAttributeRequirements(trueConstraint, offerAttribs) shouldBe true + utils.matchesAttributeRequirements(falseConstraint, offerAttribs) shouldBe false + } + +} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5f1d6daeb27f0..1f915d8ea1d73 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -184,6 +184,14 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. + +{% highlight scala %} +conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +{% endhighlight %} + +For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. + # Mesos Docker Support Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` @@ -298,6 +306,20 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.constraints + Attribute based constraints to be matched against when accepting resource offers. + + Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. +
  • Scalar constraints are matched with "less than or equal" semantics, i.e. the value in the constraint must be less than or equal to the value in the resource offer.
  • Range constraints are matched with "contains" semantics, i.e. the value in the constraint must fall within the resource offer's range.
  • Set constraints are matched with "subset of" semantics, i.e. the values in the constraint must be a subset of the resource offer's value.
  • Text constraints are matched with "equality" semantics, i.e. the value in the constraint must be exactly equal to the resource offer's value.
  • If no value is given for a constraint, any offer carrying the corresponding attribute is accepted (no value check is performed).
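For reference, a minimal illustration of the constraint-string format accepted by the `parseConstraintString` helper added in this patch: the attribute name and its value(s) are separated by `:`, multiple values by `,`, and attribute pairs by `;`. The attribute names below are taken from the new test suite and are examples only, not required settings.

{% highlight scala %}
// Only accept offers from slaves whose "tachyon" attribute is "true" and whose
// advertised zones include both us-east-1a and us-east-1b.
conf.set("spark.mesos.constraints", "tachyon:true;zone:us-east-1a,us-east-1b")

// Presence-only constraint: accept any offer that carries a "tachyon" attribute,
// regardless of its value.
conf.set("spark.mesos.constraints", "tachyon:")
{% endhighlight %}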
+ + # Troubleshooting and Debugging From 96c5eeec3970e8b1ebc6ddf5c97a7acc47f539dc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 6 Jul 2015 16:11:22 -0700 Subject: [PATCH 013/149] Revert "[SPARK-7212] [MLLIB] Add sequence learning flag" This reverts commit 25f574eb9a3cb9b93b7d9194a8ec16e00ce2c036. After speaking to some users and developers, we realized that FP-growth doesn't meet the requirement for frequent sequence mining. PrefixSpan (SPARK-6487) would be the correct algorithm for it. feynmanliang Author: Xiangrui Meng Closes #7240 from mengxr/SPARK-7212.revert and squashes the following commits: 2b3d66b [Xiangrui Meng] Revert "[SPARK-7212] [MLLIB] Add sequence learning flag" --- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 38 +++----------- .../spark/mllib/fpm/FPGrowthSuite.scala | 52 +------------------ python/pyspark/mllib/fpm.py | 4 +- 3 files changed, 12 insertions(+), 82 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index abac08022ea47..efa8459d3cdba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -36,7 +36,7 @@ import org.apache.spark.storage.StorageLevel * :: Experimental :: * * Model trained by [[FPGrowth]], which holds frequent itemsets. - * @param freqItemsets frequent itemsets, which is an RDD of [[FreqItemset]] + * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] * @tparam Item item type */ @Experimental @@ -62,14 +62,13 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex @Experimental class FPGrowth private ( private var minSupport: Double, - private var numPartitions: Int, - private var ordered: Boolean) extends Logging with Serializable { + private var numPartitions: Int) extends Logging with Serializable { /** * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same - * as the input data, ordered: `false`}. + * as the input data}. */ - def this() = this(0.3, -1, false) + def this() = this(0.3, -1) /** * Sets the minimal support level (default: `0.3`). @@ -87,15 +86,6 @@ class FPGrowth private ( this } - /** - * Indicates whether to mine itemsets (unordered) or sequences (ordered) (default: false, mine - * itemsets). - */ - def setOrdered(ordered: Boolean): this.type = { - this.ordered = ordered - this - } - /** * Computes an FP-Growth model that contains frequent itemsets. * @param data input data set, each element contains a transaction @@ -165,7 +155,7 @@ class FPGrowth private ( .flatMap { case (part, tree) => tree.extract(minCount, x => partitioner.getPartition(x) == part) }.map { case (ranks, count) => - new FreqItemset(ranks.map(i => freqItems(i)).reverse.toArray, count, ordered) + new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) } } @@ -181,12 +171,9 @@ class FPGrowth private ( itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] - // Filter the basket by frequent items pattern + // Filter the basket by frequent items pattern and sort their ranks. val filtered = transaction.flatMap(itemToRank.get) - if (!this.ordered) { - ju.Arrays.sort(filtered) - } - // Generate conditional transactions + ju.Arrays.sort(filtered) val n = filtered.length var i = n - 1 while (i >= 0) { @@ -211,18 +198,9 @@ object FPGrowth { * Frequent itemset. 
* @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. * @param freq frequency - * @param ordered indicates if items represents an itemset (false) or sequence (true) * @tparam Item item type */ - class FreqItemset[Item](val items: Array[Item], val freq: Long, val ordered: Boolean) - extends Serializable { - - /** - * Auxillary constructor, assumes unordered by default. - */ - def this(items: Array[Item], freq: Long) { - this(items, freq, false) - } + class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { /** * Returns items in a Java List. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 1a8a1e79f2810..66ae3543ecc4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { - test("FP-Growth frequent itemsets using String type") { + test("FP-Growth using String type") { val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -38,14 +38,12 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) - .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) - .setOrdered(false) .run(rdd) val freqItemsets3 = model3.freqItemsets.collect().map { itemset => (itemset.items.toSet, itemset.freq) @@ -63,59 +61,17 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) - .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 54) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) - .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 625) } - test("FP-Growth frequent sequences using String type"){ - val transactions = Seq( - "r z h k p", - "z y x w v u t s", - "s x o n r", - "x z y m t s q e", - "z", - "x z y r q t p") - .map(_.split(" ")) - val rdd = sc.parallelize(transactions, 2).cache() - - val fpg = new FPGrowth() - - val model1 = fpg - .setMinSupport(0.5) - .setNumPartitions(2) - .setOrdered(true) - .run(rdd) - - /* - Use the following R code to verify association rules using arulesSequences package. 
- - data = read_baskets("path", info = c("sequenceID","eventID","SIZE")) - freqItemSeq = cspade(data, parameter = list(support = 0.5)) - resSeq = as(freqItemSeq, "data.frame") - resSeq$support = resSeq$support * length(transactions) - names(resSeq)[names(resSeq) == "support"] = "freq" - resSeq - */ - val expected = Set( - (Seq("r"), 3L), (Seq("s"), 3L), (Seq("t"), 3L), (Seq("x"), 4L), (Seq("y"), 3L), - (Seq("z"), 5L), (Seq("z", "y"), 3L), (Seq("x", "t"), 3L), (Seq("y", "t"), 3L), - (Seq("z", "t"), 3L), (Seq("z", "y", "t"), 3L) - ) - val freqItemseqs1 = model1.freqItemsets.collect().map { itemset => - (itemset.items.toSeq, itemset.freq) - }.toSet - assert(freqItemseqs1 == expected) - } - - test("FP-Growth frequent itemsets using Int type") { + test("FP-Growth using Int type") { val transactions = Seq( "1 2 3", "1 2 3 4", @@ -132,14 +88,12 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) - .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) - .setOrdered(false) .run(rdd) assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, "frequent itemsets should use primitive arrays") @@ -155,14 +109,12 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) - .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 15) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) - .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 65) } diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index b7f00d60069e6..bdc4a132b1b18 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -39,8 +39,8 @@ class FPGrowthModel(JavaModelWrapper): >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] >>> rdd = sc.parallelize(data, 2) >>> model = FPGrowth.train(rdd, 0.6, 2) - >>> sorted(model.freqItemsets().collect(), key=lambda x: x.items) - [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'a', u'c'], freq=3), ... + >>> sorted(model.freqItemsets().collect()) + [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... """ def freqItemsets(self): From 0effe180f4c2cf37af1012b33b43912bdecaf756 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 6 Jul 2015 16:15:12 -0700 Subject: [PATCH 014/149] [SPARK-8765] [MLLIB] Fix PySpark PowerIterationClustering test issue PySpark PowerIterationClustering test failure due to bad demo data. If the data is small, PowerIterationClustering will behavior indeterministic. Author: Yanbo Liang Closes #7177 from yanboliang/spark-8765 and squashes the following commits: 392ae54 [Yanbo Liang] fix model.assignments output 5ec3f1e [Yanbo Liang] fix PySpark PowerIterationClustering test issue --- python/pyspark/mllib/clustering.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index a3eab635282f6..ed4d78a2c6788 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -282,18 +282,30 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): Model produced by [[PowerIterationClustering]]. - >>> data = [(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0), - ... 
(0, 3, 1.0), (1, 2, 1.0), (0, 4, 0.1)] + >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0), + ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), + ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0), + ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)] >>> rdd = sc.parallelize(data, 2) >>> model = PowerIterationClustering.train(rdd, 2, 100) >>> model.k 2 + >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) + >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster + True + >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = PowerIterationClusteringModel.load(sc, path) >>> sameModel.k 2 + >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) + >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster + True + >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster + True >>> from shutil import rmtree >>> try: ... rmtree(path) From 7b467cc9348fa910e445ad08914a72f8ed4fc249 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 6 Jul 2015 16:26:31 -0700 Subject: [PATCH 015/149] [SPARK-8588] [SQL] Regression test This PR adds regression test for https://issues.apache.org/jira/browse/SPARK-8588 (fixed by https://github.com/apache/spark/commit/457d07eaa023b44b75344110508f629925eb6247). Author: Yin Huai This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #7103 from yhuai/SPARK-8588-test and squashes the following commits: eb5f418 [Yin Huai] Add a query test. c61a173 [Yin Huai] Regression test for SPARK-8588. --- .../analysis/HiveTypeCoercionSuite.scala | 21 +++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 16 ++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index b56426617789e..93db33d44eb25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -271,4 +271,25 @@ class HiveTypeCoercionSuite extends PlanTest { Literal(true) ) } + + /** + * There are rules that need to not fire before child expressions get resolved. + * We use this test to make sure those rules do not fire early. 
+ */ + test("make sure rules do not fire early") { + // InConversion + val inConversion = HiveTypeCoercion.InConversion + ruleTest(inConversion, + In(UnresolvedAttribute("a"), Seq(Literal(1))), + In(UnresolvedAttribute("a"), Seq(Literal(1))) + ) + ruleTest(inConversion, + In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))), + In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))) + ) + ruleTest(inConversion, + In(Literal("a"), Seq(Literal(1), Literal("b"))), + In(Literal("a"), Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) + ) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6d645393a6da1..bf9f2ecd51793 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -990,5 +990,21 @@ class SQLQuerySuite extends QueryTest { Timestamp.valueOf("1969-12-31 16:00:00"), String.valueOf("1969-12-31 16:00:00"), Timestamp.valueOf("1970-01-01 00:00:00"))) + + } + + test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { + val df = + TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) + df.toDF("id", "date").registerTempTable("test_SPARK8588") + checkAnswer( + TestHive.sql( + """ + |select id, concat(year(date)) + |from test_SPARK8588 where concat(year(date), ' year') in ('2015 year', '2014 year') + """.stripMargin), + Row(1, "2014") :: Row(2, "2015") :: Nil + ) + TestHive.dropTempTable("test_SPARK8588") } } From 09a06418debc25da0191d98798f7c5016d39be91 Mon Sep 17 00:00:00 2001 From: animesh Date: Mon, 6 Jul 2015 16:39:49 -0700 Subject: [PATCH 016/149] [SPARK-8072] [SQL] Better AnalysisException for writing DataFrame with identically named columns Adding a function checkConstraints which will check for the constraints to be applied on the dataframe / dataframe schema. Function called before storing the dataframe to an external storage. Function added in the corresponding datasource API. 
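A minimal, self-contained sketch of the constraint check described above. The object and method names are invented for illustration (they are not the helpers added by this patch), and a plain exception stands in for Spark's AnalysisException so the snippet runs without Spark on the classpath:

// Groups the schema's field names, keeps the ones that occur more than once,
// and fails before any data would be written.
object DuplicateColumnCheckSketch {

  def findDuplicates(fieldNames: Seq[String]): Seq[String] =
    fieldNames.groupBy(identity).collect {
      case (name, occurrences) if occurrences.length > 1 => name
    }.toSeq

  def main(args: Array[String]): Unit = {
    val fieldNames = Seq("column1", "column2", "column1", "column3", "column3")
    val duplicates = findDuplicates(fieldNames)
    if (duplicates.nonEmpty) {
      // The real patch throws AnalysisException with a similar message.
      throw new IllegalArgumentException(
        s"Duplicate column(s) found: ${duplicates.mkString(", ")}")
    }
  }
}

The same grouping check is what the patch wires into the JSON and Parquet write paths, so the problem is reported as an AnalysisException before any data is written.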
cc rxin marmbrus Author: animesh This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #7013 from animeshbaranawal/8072 and squashes the following commits: f70dd0e [animesh] Change IO exception to Analysis Exception fd45e1b [animesh] 8072: Fix Style Issues a8a964f [animesh] 8072: Improving on previous commits 3cc4d2c [animesh] Fix Style Issues 1a89115 [animesh] Fix Style Issues 98b4399 [animesh] 8072 : Moved the exception handling to ResolvedDataSource specific to parquet format 7c3d928 [animesh] 8072: Adding check to DataFrameWriter.scala --- .../apache/spark/sql/json/JSONRelation.scala | 31 +++++++++++++++++++ .../apache/spark/sql/parquet/newParquet.scala | 19 +++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 24 ++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 69bf13e1e5a6a..2361d3bf52d2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -22,6 +22,7 @@ import java.io.IOException import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -37,6 +38,17 @@ private[sql] class DefaultSource parameters.getOrElse("path", sys.error("'path' must be specified for json data.")) } + /** Constraints to be imposed on dataframe to be stored. */ + private def checkConstraints(data: DataFrame): Unit = { + if (data.schema.fieldNames.length != data.schema.fieldNames.distinct.length) { + val duplicateColumns = data.schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to JSON format") + } + } + /** Returns a new base relation with the parameters. */ override def createRelation( sqlContext: SQLContext, @@ -63,6 +75,10 @@ private[sql] class DefaultSource mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = { + // check if dataframe satisfies the constraints + // before moving forward + checkConstraints(data) + val path = checkPath(parameters) val filesystemPath = new Path(path) val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) @@ -130,6 +146,17 @@ private[sql] class JSONRelation( samplingRatio, userSpecifiedSchema)(sqlContext) + /** Constraints to be imposed on dataframe to be stored. 
*/ + private def checkConstraints(data: DataFrame): Unit = { + if (data.schema.fieldNames.length != data.schema.fieldNames.distinct.length) { + val duplicateColumns = data.schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to JSON format") + } + } + private val useJacksonStreamingAPI: Boolean = sqlContext.conf.useJacksonStreamingAPI override val needConversion: Boolean = false @@ -178,6 +205,10 @@ private[sql] class JSONRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { + // check if dataframe satisfies constraints + // before moving forward + checkConstraints(data) + val filesystemPath = path match { case Some(p) => new Path(p) case None => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 5ac3e9a44e6fe..6bc69c6ad0847 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -164,7 +164,24 @@ private[sql] class ParquetRelation2( } } - override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema) + /** Constraints on schema of dataframe to be stored. */ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to parquet format") + } + } + + override def dataSchema: StructType = { + val schema = maybeDataSchema.getOrElse(metadataCache.dataSchema) + // check if schema satisfies the constraints + // before moving forward + checkConstraints(schema) + schema + } override private[sql] def refresh(): Unit = { super.refresh() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index afb1cf5f8d1cb..f592a9934d0e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -737,4 +737,28 @@ class DataFrameSuite extends QueryTest { df.col("") df.col("t.``") } + + test("SPARK-8072: Better Exception for Duplicate Columns") { + // only one duplicate column present + val e = intercept[org.apache.spark.sql.AnalysisException] { + val df1 = Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") + .write.format("parquet").save("temp") + } + assert(e.getMessage.contains("Duplicate column(s)")) + assert(e.getMessage.contains("parquet")) + assert(e.getMessage.contains("column1")) + assert(!e.getMessage.contains("column2")) + + // multiple duplicate columns present + val f = intercept[org.apache.spark.sql.AnalysisException] { + val df2 = Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) + .toDF("column1", "column2", "column3", "column1", "column3") + .write.format("json").save("temp") + } + assert(f.getMessage.contains("Duplicate column(s)")) + assert(f.getMessage.contains("JSON")) + assert(f.getMessage.contains("column1")) + assert(f.getMessage.contains("column3")) + assert(!f.getMessage.contains("column2")) + } } From d4d6d31db5cc5c69ac369f754b7489f444c9ba2f Mon Sep 
17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 6 Jul 2015 17:16:44 -0700 Subject: [PATCH 017/149] [SPARK-8463][SQL] Use DriverRegistry to load jdbc driver at writing path JIRA: https://issues.apache.org/jira/browse/SPARK-8463 Currently, at the reading path, `DriverRegistry` is used to load needed jdbc driver at executors. However, at the writing path, we also need `DriverRegistry` to load jdbc driver. Author: Liang-Chi Hsieh Closes #6900 from viirya/jdbc_write_driver and squashes the following commits: 16cd04b [Liang-Chi Hsieh] Use DriverRegistry to load jdbc driver at writing path. --- .../main/scala/org/apache/spark/sql/jdbc/jdbc.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index dd8aaf6474895..f7ea852fe7f58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -58,13 +58,12 @@ package object jdbc { * are used. */ def savePartition( - url: String, + getConnection: () => Connection, table: String, iterator: Iterator[Row], rddSchema: StructType, - nullTypes: Array[Int], - properties: Properties): Iterator[Byte] = { - val conn = DriverManager.getConnection(url, properties) + nullTypes: Array[Int]): Iterator[Byte] = { + val conn = getConnection() var committed = false try { conn.setAutoCommit(false) // Everything in the same db transaction. @@ -185,8 +184,10 @@ package object jdbc { } val rddSchema = df.schema + val driver: String = DriverRegistry.getDriverClassName(url) + val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes, properties) + JDBCWriteDetails.savePartition(getConnection, table, iterator, rddSchema, nullTypes) } } From 9eae5fa642317dd11fc783d832d4cbb7e62db471 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 6 Jul 2015 19:22:30 -0700 Subject: [PATCH 018/149] [SPARK-8819] Fix build for maven 3.3.x This is a workaround for MSHADE-148, which leads to an infinite loop when building Spark with maven 3.3.x. This was originally caused by #6441, which added a bunch of test dependencies on the spark-core test module. Recently, it was revealed by #7193. This patch adds a `-Prelease` profile. If present, it will set `createDependencyReducedPom` to true. The consequences are: - If you are releasing Spark with this profile, you are fine as long as you use maven 3.2.x or before. - If you are releasing Spark without this profile, you will run into SPARK-8781. - If you are not releasing Spark but you are using this profile, you may run into SPARK-8819. - If you are not releasing Spark and you did not include this profile, you are fine. This is all documented in `pom.xml` and tested locally with both versions of maven. 
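Stepping back briefly to the SPARK-8463 change earlier in this series: the essence of that patch is to ship a () => Connection factory to the write path instead of a raw URL, so the JDBC driver class is resolved wherever the partition is actually written. A rough, standalone sketch of that pattern follows; the object and method names and the H2 driver/URL are assumptions for illustration only (they are not Spark APIs), and running it requires an H2 jar on the classpath:

import java.sql.{Connection, DriverManager}

object JdbcWritePatternSketch {

  // Stand-in for building the connection factory up front
  // (the patch uses DriverRegistry plus JDBCRDD.getConnector for this).
  def makeConnectionFactory(driverClass: String, url: String): () => Connection =
    () => {
      Class.forName(driverClass) // make sure the driver is loaded where the closure runs
      DriverManager.getConnection(url)
    }

  // Stand-in for savePartition: only the connection handling is sketched.
  def savePartition(getConnection: () => Connection, rows: Iterator[String]): Unit = {
    val conn = getConnection()
    try {
      conn.setAutoCommit(false) // everything in one transaction, as in the real code
      rows.foreach(_ => ())     // the real code binds and executes an INSERT per row
      conn.commit()
    } finally {
      conn.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val getConnection = makeConnectionFactory("org.h2.Driver", "jdbc:h2:mem:sketch")
    savePartition(getConnection, Iterator("row1", "row2"))
  }
}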
Author: Andrew Or Closes #7219 from andrewor14/fix-maven-build and squashes the following commits: 1d37e87 [Andrew Or] Merge branch 'master' of github.com:apache/spark into fix-maven-build 3574ae4 [Andrew Or] Review comments f39199c [Andrew Or] Create a -Prelease profile that flags `createDependencyReducedPom` --- dev/create-release/create-release.sh | 4 ++-- pom.xml | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 54274a83f6d66..cfe2cd4752b3f 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -118,13 +118,13 @@ if [[ ! "$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO - build/mvn -DskipTests -Pyarn -Phive \ + build/mvn -DskipTests -Pyarn -Phive -Prelease-profile\ -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh - build/mvn -DskipTests -Pyarn -Phive \ + build/mvn -DskipTests -Pyarn -Phive -Prelease-profile\ -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install diff --git a/pom.xml b/pom.xml index ffa96128a3d61..fbcc9152765cf 100644 --- a/pom.xml +++ b/pom.xml @@ -161,6 +161,8 @@ 2.4.4 1.1.1.7 1.1.2 + + false ${java.home} @@ -1440,6 +1442,8 @@ 2.3 false + + ${create.dependency.reduced.pom} @@ -1826,6 +1830,26 @@ + + + release-profile + + + true + + + - release-profile + release false @@ -179,6 +180,8 @@ compile compile compile + test + test + + twttr-repo + Twttr Repository + http://maven.twttr.com + + true + + + false + + spark-1.4-staging @@ -1101,6 +1116,24 @@ ${parquet.version} ${parquet.deps.scope} + + org.apache.parquet + parquet-avro + ${parquet.version} + ${parquet.test.deps.scope} + + + org.apache.parquet + parquet-thrift + ${parquet.version} + ${parquet.test.deps.scope} + + + org.apache.thrift + libthrift + ${thrift.version} + ${thrift.test.deps.scope} + org.apache.flume flume-ng-core diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 41e19fd9cc11e..7346d804632bc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,21 +62,8 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticCostFun.this"), // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), - // NanoTime and CatalystTimestampConverter is only used inside catalyst, - // not needed anymore - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.timestamp.NanoTime"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.timestamp.NanoTime$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.CatalystTimestampConverter"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.CatalystTimestampConverter$"), - // SPARK-6777 Implements backwards compatibility rules in CatalystSchemaConverter - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTypeInfo"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTypeInfo$") + // Parquet support is considered private. + excludePackage("org.apache.spark.sql.parquet") ) ++ Seq( // SPARK-8479 Add numNonzeros and numActives to Matrix. 
ProblemFilters.exclude[MissingMethodProblem]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 7d00047d08d74..a4c2da8e05f5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.types +import scala.util.Try import scala.util.parsing.combinator.RegexParsers -import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ +import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi @@ -82,6 +83,9 @@ abstract class DataType extends AbstractDataType { object DataType { + private[sql] def fromString(raw: String): DataType = { + Try(DataType.fromJson(raw)).getOrElse(DataType.fromCaseClassString(raw)) + } def fromJson(json: String): DataType = parseDataType(parse(json)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3b17566d54d9b..e2d3f53f7d978 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -311,6 +311,11 @@ object StructType extends AbstractDataType { private[sql] override def simpleString: String = "struct" + private[sql] def fromString(raw: String): StructType = DataType.fromString(raw) match { + case t: StructType => t + case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + } + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 8fc16928adbd9..f90099f22d4bd 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -101,9 +101,45 @@ 9.3-1102-jdbc41 test + + org.apache.parquet + parquet-avro + test + + + org.apache.parquet + parquet-thrift + test + + + org.apache.thrift + libthrift + test + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + src/test/gen-java + + + + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala new file mode 100644 index 0000000000000..0c3d8fdab6bd2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.nio.ByteOrder + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ParentContainerUpdater]] is used by a Parquet converter to set converted values to some + * corresponding parent container. For example, a converter for a `StructType` field may set + * converted values to a [[MutableRow]]; or a converter for array elements may append converted + * values to an [[ArrayBuffer]]. + */ +private[parquet] trait ParentContainerUpdater { + def set(value: Any): Unit = () + def setBoolean(value: Boolean): Unit = set(value) + def setByte(value: Byte): Unit = set(value) + def setShort(value: Short): Unit = set(value) + def setInt(value: Int): Unit = set(value) + def setLong(value: Long): Unit = set(value) + def setFloat(value: Float): Unit = set(value) + def setDouble(value: Double): Unit = set(value) +} + +/** A no-op updater used for root converter (who doesn't have a parent). */ +private[parquet] object NoopUpdater extends ParentContainerUpdater + +/** + * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[Row]]s. Since + * any Parquet record is also a struct, this converter can also be used as root converter. + * + * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have + * any "parent" container. + * + * @param parquetType Parquet schema of Parquet records + * @param catalystType Spark SQL schema that corresponds to the Parquet record type + * @param updater An updater which propagates converted field values to the parent container + */ +private[parquet] class CatalystRowConverter( + parquetType: GroupType, + catalystType: StructType, + updater: ParentContainerUpdater) + extends GroupConverter { + + /** + * Updater used together with field converters within a [[CatalystRowConverter]]. It propagates + * converted filed values to the `ordinal`-th cell in `currentRow`. + */ + private final class RowUpdater(row: MutableRow, ordinal: Int) extends ParentContainerUpdater { + override def set(value: Any): Unit = row(ordinal) = value + override def setBoolean(value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(value: Short): Unit = row.setShort(ordinal, value) + override def setInt(value: Int): Unit = row.setInt(ordinal, value) + override def setLong(value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(value: Float): Unit = row.setFloat(ordinal, value) + } + + /** + * Represents the converted row object once an entire Parquet record is converted. + * + * @todo Uses [[UnsafeRow]] for better performance. + */ + val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) + + // Converters for each field. 
+ private val fieldConverters: Array[Converter] = { + parquetType.getFields.zip(catalystType).zipWithIndex.map { + case ((parquetFieldType, catalystField), ordinal) => + // Converted field value should be set to the `ordinal`-th cell of `currentRow` + newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal)) + }.toArray + } + + override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) + + override def end(): Unit = updater.set(currentRow) + + override def start(): Unit = { + var i = 0 + while (i < currentRow.length) { + currentRow.setNullAt(i) + i += 1 + } + } + + /** + * Creates a converter for the given Parquet type `parquetType` and Spark SQL data type + * `catalystType`. Converted values are handled by `updater`. + */ + private def newConverter( + parquetType: Type, + catalystType: DataType, + updater: ParentContainerUpdater): Converter = { + + catalystType match { + case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => + new CatalystPrimitiveConverter(updater) + + case ByteType => + new PrimitiveConverter { + override def addInt(value: Int): Unit = + updater.setByte(value.asInstanceOf[ByteType#InternalType]) + } + + case ShortType => + new PrimitiveConverter { + override def addInt(value: Int): Unit = + updater.setShort(value.asInstanceOf[ShortType#InternalType]) + } + + case t: DecimalType => + new CatalystDecimalConverter(t, updater) + + case StringType => + new CatalystStringConverter(updater) + + case TimestampType => + // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. + new PrimitiveConverter { + // Converts nanosecond timestamps stored as INT96 + override def addBinary(value: Binary): Unit = { + assert( + value.length() == 12, + "Timestamps (with nanoseconds) are expected to be stored in 12-byte long binaries, " + + s"but got a ${value.length()}-byte binary.") + + val buf = value.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) + val timeOfDayNanos = buf.getLong + val julianDay = buf.getInt + updater.setLong(DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos)) + } + } + + case DateType => + new PrimitiveConverter { + override def addInt(value: Int): Unit = { + // DateType is not specialized in `SpecificMutableRow`, have to box it here. + updater.set(value.asInstanceOf[DateType#InternalType]) + } + } + + case t: ArrayType => + new CatalystArrayConverter(parquetType.asGroupType(), t, updater) + + case t: MapType => + new CatalystMapConverter(parquetType.asGroupType(), t, updater) + + case t: StructType => + new CatalystRowConverter(parquetType.asGroupType(), t, new ParentContainerUpdater { + override def set(value: Any): Unit = updater.set(value.asInstanceOf[Row].copy()) + }) + + case t: UserDefinedType[_] => + val catalystTypeForUDT = t.sqlType + val nullable = parquetType.isRepetition(Repetition.OPTIONAL) + val field = StructField("udt", catalystTypeForUDT, nullable) + val parquetTypeForUDT = new CatalystSchemaConverter().convertField(field) + newConverter(parquetTypeForUDT, catalystTypeForUDT, updater) + + case _ => + throw new RuntimeException( + s"Unable to create Parquet converter for data type ${catalystType.json}") + } + } + + /** + * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types + * are handled by this converter. Parquet primitive types are only a subset of those of Spark + * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. 
+ */ + private final class CatalystPrimitiveConverter(updater: ParentContainerUpdater) + extends PrimitiveConverter { + + override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) + override def addInt(value: Int): Unit = updater.setInt(value) + override def addLong(value: Long): Unit = updater.setLong(value) + override def addFloat(value: Float): Unit = updater.setFloat(value) + override def addDouble(value: Double): Unit = updater.setDouble(value) + override def addBinary(value: Binary): Unit = updater.set(value.getBytes) + } + + /** + * Parquet converter for strings. A dictionary is used to minimize string decoding cost. + */ + private final class CatalystStringConverter(updater: ParentContainerUpdater) + extends PrimitiveConverter { + + private var expandedDictionary: Array[UTF8String] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary): Unit = { + this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { i => + UTF8String.fromBytes(dictionary.decodeToBinary(i).getBytes) + } + } + + override def addValueFromDictionary(dictionaryId: Int): Unit = { + updater.set(expandedDictionary(dictionaryId)) + } + + override def addBinary(value: Binary): Unit = { + updater.set(UTF8String.fromBytes(value.getBytes)) + } + } + + /** + * Parquet converter for fixed-precision decimals. + */ + private final class CatalystDecimalConverter( + decimalType: DecimalType, + updater: ParentContainerUpdater) + extends PrimitiveConverter { + + // Converts decimals stored as INT32 + override def addInt(value: Int): Unit = { + addLong(value: Long) + } + + // Converts decimals stored as INT64 + override def addLong(value: Long): Unit = { + updater.set(Decimal(value, decimalType.precision, decimalType.scale)) + } + + // Converts decimals stored as either FIXED_LENGTH_BYTE_ARRAY or BINARY + override def addBinary(value: Binary): Unit = { + updater.set(toDecimal(value)) + } + + private def toDecimal(value: Binary): Decimal = { + val precision = decimalType.precision + val scale = decimalType.scale + val bytes = value.getBytes + + var unscaled = 0L + var i = 0 + + while (i < bytes.length) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * bytes.length + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + Decimal(unscaled, precision, scale) + } + } + + /** + * Parquet converter for arrays. Spark SQL arrays are represented as Parquet lists. Standard + * Parquet lists are represented as a 3-level group annotated by `LIST`: + * {{{ + * group (LIST) { <-- parquetSchema points here + * repeated group list { + * element; + * } + * } + * }}} + * The `parquetSchema` constructor argument points to the outermost group. + * + * However, before this representation is standardized, some Parquet libraries/tools also use some + * non-standard formats to represent list-like structures. Backwards-compatibility rules for + * handling these cases are described in Parquet format spec. 
+ * + * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + */ + private final class CatalystArrayConverter( + parquetSchema: GroupType, + catalystSchema: ArrayType, + updater: ParentContainerUpdater) + extends GroupConverter { + + private var currentArray: ArrayBuffer[Any] = _ + + private val elementConverter: Converter = { + val repeatedType = parquetSchema.getType(0) + val elementType = catalystSchema.elementType + + if (isElementType(repeatedType, elementType)) { + newConverter(repeatedType, elementType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentArray += value + }) + } else { + new ElementConverter(repeatedType.asGroupType().getType(0), elementType) + } + } + + override def getConverter(fieldIndex: Int): Converter = elementConverter + + override def end(): Unit = updater.set(currentArray) + + // NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the + // next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored + // in row cells. + override def start(): Unit = currentArray = ArrayBuffer.empty[Any] + + // scalastyle:off + /** + * Returns whether the given type is the element type of a list or is a syntactic group with + * one field that is the element type. This is determined by checking whether the type can be + * a syntactic group and by checking whether a potential syntactic group matches the expected + * schema. + * {{{ + * group (LIST) { + * repeated group list { <-- repeatedType points here + * element; + * } + * } + * }}} + * In short, here we handle Parquet list backwards-compatibility rules on the read path. This + * method is based on `AvroIndexedRecordConverter.isElementType`. + * + * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + */ + // scalastyle:on + private def isElementType(parquetRepeatedType: Type, catalystElementType: DataType): Boolean = { + (parquetRepeatedType, catalystElementType) match { + case (t: PrimitiveType, _) => true + case (t: GroupType, _) if t.getFieldCount > 1 => true + case (t: GroupType, StructType(Array(f))) if f.name == t.getFieldName(0) => true + case _ => false + } + } + + /** Array element converter */ + private final class ElementConverter(parquetType: Type, catalystType: DataType) + extends GroupConverter { + + private var currentElement: Any = _ + + private val converter = newConverter(parquetType, catalystType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentElement = value + }) + + override def getConverter(fieldIndex: Int): Converter = converter + + override def end(): Unit = currentArray += currentElement + + override def start(): Unit = currentElement = null + } + } + + /** Parquet converter for maps */ + private final class CatalystMapConverter( + parquetType: GroupType, + catalystType: MapType, + updater: ParentContainerUpdater) + extends GroupConverter { + + private var currentMap: mutable.Map[Any, Any] = _ + + private val keyValueConverter = { + val repeatedType = parquetType.getType(0).asGroupType() + new KeyValueConverter( + repeatedType.getType(0), + repeatedType.getType(1), + catalystType.keyType, + catalystType.valueType) + } + + override def getConverter(fieldIndex: Int): Converter = keyValueConverter + + override def end(): Unit = updater.set(currentMap) + + // NOTE: We can't reuse the mutable Map here and must instantiate a new `Map` for the next + // value. 
`Row.copy()` only copies row cells, it doesn't do deep copy to objects stored in row + // cells. + override def start(): Unit = currentMap = mutable.Map.empty[Any, Any] + + /** Parquet converter for key-value pairs within the map. */ + private final class KeyValueConverter( + parquetKeyType: Type, + parquetValueType: Type, + catalystKeyType: DataType, + catalystValueType: DataType) + extends GroupConverter { + + private var currentKey: Any = _ + + private var currentValue: Any = _ + + private val converters = Array( + // Converter for keys + newConverter(parquetKeyType, catalystKeyType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentKey = value + }), + + // Converter for values + newConverter(parquetValueType, catalystValueType, new ParentContainerUpdater { + override def set(value: Any): Unit = currentValue = value + })) + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + override def end(): Unit = currentMap(currentKey) = currentValue + + override def start(): Unit = { + currentKey = null + currentValue = null + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 4ab274ec17a02..de3a72d8146c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -358,9 +358,24 @@ private[parquet] class CatalystSchemaConverter( case DateType => Types.primitive(INT32, repetition).as(DATE).named(field.name) - // NOTE: !! This timestamp type is not specified in Parquet format spec !! - // However, Impala and older versions of Spark SQL use INT96 to store timestamps with - // nanosecond precision (not TIME_MILLIS or TIMESTAMP_MILLIS described in the spec). + // NOTE: Spark SQL TimestampType is NOT a well defined type in Parquet format spec. + // + // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond + // timestamp in Impala for some historical reasons, it's not recommended to be used for any + // other types and will probably be deprecated in future Parquet format spec. That's the + // reason why Parquet format spec only defines `TIMESTAMP_MILLIS` and `TIMESTAMP_MICROS` which + // are both logical types annotating `INT64`. + // + // Originally, Spark SQL uses the same nanosecond timestamp type as Impala and Hive. Starting + // from Spark 1.5.0, we resort to a timestamp type with 100 ns precision so that we can store + // a timestamp into a `Long`. This design decision is subject to change though, for example, + // we may resort to microsecond precision in the future. + // + // For Parquet, we plan to write all `TimestampType` value as `TIMESTAMP_MICROS`, but it's + // currently not implemented yet because parquet-mr 1.7.0 (the version we're currently using) + // hasn't implemented `TIMESTAMP_MICROS` yet. + // + // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. 
case TimestampType => Types.primitive(INT96, repetition).named(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 86a77bf965daa..be0a2029d233b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,61 +17,15 @@ package org.apache.spark.sql.parquet -import java.nio.ByteOrder - -import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap} - -import org.apache.parquet.Preconditions -import org.apache.parquet.column.Dictionary -import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.MessageType - import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.parquet.CatalystConverter.FieldType -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * Collection of converters of Parquet types (group and primitive types) that - * model arrays and maps. The conversions are partly based on the AvroParquet - * converters that are part of Parquet in order to be able to process these - * types. - * - * There are several types of converters: - *
<ul>
- *   <li>[[org.apache.spark.sql.parquet.CatalystPrimitiveConverter]] for primitive
- *   (numeric, boolean and String) types</li>
- *   <li>[[org.apache.spark.sql.parquet.CatalystNativeArrayConverter]] for arrays
- *   of native JVM element types; note: currently null values are not supported!</li>
- *   <li>[[org.apache.spark.sql.parquet.CatalystArrayConverter]] for arrays of
- *   arbitrary element types (including nested element types); note: currently
- *   null values are not supported!</li>
- *   <li>[[org.apache.spark.sql.parquet.CatalystStructConverter]] for structs</li>
- *   <li>[[org.apache.spark.sql.parquet.CatalystMapConverter]] for maps; note:
- *   currently null values are not supported!</li>
- *   <li>[[org.apache.spark.sql.parquet.CatalystPrimitiveRowConverter]] for rows
- *   of only primitive element types</li>
- *   <li>[[org.apache.spark.sql.parquet.CatalystGroupConverter]] for other nested
- *   records, including the top-level row record</li>
- * </ul>
- */ private[sql] object CatalystConverter { - // The type internally used for fields - type FieldType = StructField - // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). // Note that "array" for the array elements is chosen by ParquetAvro. // Using a different value will result in Parquet silently dropping columns. val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag" val ARRAY_ELEMENTS_SCHEMA_NAME = "array" - // SPARK-4520: Thrift generated parquet files have different array element - // schema names than avro. Thrift parquet uses array_schema_name + "_tuple" - // as opposed to "array" used by default. For more information, check - // TestThriftSchemaConverter.java in parquet.thrift. - val THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX = "_tuple" + val MAP_KEY_SCHEMA_NAME = "key" val MAP_VALUE_SCHEMA_NAME = "value" val MAP_SCHEMA_NAME = "map" @@ -80,787 +34,4 @@ private[sql] object CatalystConverter { type ArrayScalaType[T] = Seq[T] type StructScalaType[T] = InternalRow type MapScalaType[K, V] = Map[K, V] - - protected[parquet] def createConverter( - field: FieldType, - fieldIndex: Int, - parent: CatalystConverter): Converter = { - val fieldType: DataType = field.dataType - fieldType match { - case udt: UserDefinedType[_] => { - createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) - } - // For native JVM types we use a converter with native arrays - case ArrayType(elementType: AtomicType, false) => { - new CatalystNativeArrayConverter(elementType, fieldIndex, parent) - } - // This is for other types of arrays, including those with nested fields - case ArrayType(elementType: DataType, false) => { - new CatalystArrayConverter(elementType, fieldIndex, parent) - } - case ArrayType(elementType: DataType, true) => { - new CatalystArrayContainsNullConverter(elementType, fieldIndex, parent) - } - case StructType(fields: Array[StructField]) => { - new CatalystStructConverter(fields, fieldIndex, parent) - } - case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => { - new CatalystMapConverter( - Array( - new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), - new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)), - fieldIndex, - parent) - } - // Strings, Shorts and Bytes do not have a corresponding type in Parquet - // so we need to treat them separately - case StringType => - new CatalystPrimitiveStringConverter(parent, fieldIndex) - case ShortType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.InternalType]) - } - } - case ByteType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.InternalType]) - } - } - case DateType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addInt(value: Int): Unit = - parent.updateDate(fieldIndex, value.asInstanceOf[DateType.InternalType]) - } - } - case d: DecimalType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateDecimal(fieldIndex, value, d) - } - } - case TimestampType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateTimestamp(fieldIndex, value) - } - } - // All other primitive types use the default converter - case ctype: DataType if ParquetTypesConverter.isPrimitiveType(ctype) => { - // note: need 
the type tag here! - new CatalystPrimitiveConverter(parent, fieldIndex) - } - case _ => throw new RuntimeException( - s"unable to convert datatype ${field.dataType.toString} in CatalystConverter") - } - } - - protected[parquet] def createRootConverter( - parquetSchema: MessageType, - attributes: Seq[Attribute]): CatalystConverter = { - // For non-nested types we use the optimized Row converter - if (attributes.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))) { - new CatalystPrimitiveRowConverter(attributes.toArray) - } else { - new CatalystGroupConverter(attributes.toArray) - } - } -} - -private[parquet] abstract class CatalystConverter extends GroupConverter { - /** - * The number of fields this group has - */ - protected[parquet] val size: Int - - /** - * The index of this converter in the parent - */ - protected[parquet] val index: Int - - /** - * The parent converter - */ - protected[parquet] val parent: CatalystConverter - - /** - * Called by child converters to update their value in its parent (this). - * Note that if possible the more specific update methods below should be used - * to avoid auto-boxing of native JVM types. - * - * @param fieldIndex - * @param value - */ - protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit - - protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = - updateField(fieldIndex, value) - - protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, value.getBytes) - - protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - updateField(fieldIndex, UTF8String.fromBytes(value)) - - protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, readTimestamp(value)) - - protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = - updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) - - protected[parquet] def isRootConverter: Boolean = parent == null - - protected[parquet] def clearBuffer(): Unit - - /** - * Should only be called in the root (group) converter! - * - * @return - */ - def getCurrentRecord: InternalRow = throw new UnsupportedOperationException - - /** - * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in - * a long (i.e. precision <= 18) - * - * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object. 
- */ - protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = { - val precision = ctype.precisionInfo.get.precision - val scale = ctype.precisionInfo.get.scale - val bytes = value.getBytes - require(bytes.length <= 16, "Decimal field too large to read") - var unscaled = 0L - var i = 0 - while (i < bytes.length) { - unscaled = (unscaled << 8) | (bytes(i) & 0xFF) - i += 1 - } - // Make sure unscaled has the right sign, by sign-extending the first bit - val numBits = 8 * bytes.length - unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) - dest.set(unscaled, precision, scale) - } - - /** - * Read a Timestamp value from a Parquet Int96Value - */ - protected[parquet] def readTimestamp(value: Binary): Long = { - Preconditions.checkArgument(value.length() == 12, "Must be 12 bytes") - val buf = value.toByteBuffer - buf.order(ByteOrder.LITTLE_ENDIAN) - val timeOfDayNanos = buf.getLong - val julianDay = buf.getInt - DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) - } -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. - * - * @param schema The corresponding Catalyst schema in the form of a list of attributes. - */ -private[parquet] class CatalystGroupConverter( - protected[parquet] val schema: Array[FieldType], - protected[parquet] val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var current: ArrayBuffer[Any], - protected[parquet] var buffer: ArrayBuffer[InternalRow]) - extends CatalystConverter { - - def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) = - this( - schema, - index, - parent, - current = null, - buffer = new ArrayBuffer[InternalRow]( - CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - /** - * This constructor is used for the root converter only! - */ - def this(attributes: Array[Attribute]) = - this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null) - - protected [parquet] val converters: Array[Converter] = - schema.zipWithIndex.map { - case (field, idx) => CatalystConverter.createConverter(field, idx, this) - }.toArray - - override val size = schema.size - - override def getCurrentRecord: InternalRow = { - assert(isRootConverter, "getCurrentRecord should only be called in root group converter!") - // TODO: use iterators if possible - // Note: this will ever only be called in the root converter when the record has been - // fully processed. Therefore it will be difficult to use mutable rows instead, since - // any non-root converter never would be sure when it would be safe to re-use the buffer. 
- new GenericInternalRow(current.toArray) - } - - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) - - // for child converters to update upstream values - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - current.update(fieldIndex, value) - } - - override protected[parquet] def clearBuffer(): Unit = buffer.clear() - - override def start(): Unit = { - current = ArrayBuffer.fill(size)(null) - converters.foreach { converter => - if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer() - } - } - } - - override def end(): Unit = { - if (!isRootConverter) { - assert(current != null) // there should be no empty groups - buffer.append(new GenericInternalRow(current.toArray)) - parent.updateField(index, new GenericInternalRow(buffer.toArray.asInstanceOf[Array[Any]])) - } - } -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. Note that his - * converter is optimized for rows of primitive types (non-nested records). - */ -private[parquet] class CatalystPrimitiveRowConverter( - protected[parquet] val schema: Array[FieldType], - protected[parquet] var current: MutableRow) - extends CatalystConverter { - - // This constructor is used for the root converter only - def this(attributes: Array[Attribute]) = - this( - attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), - new SpecificMutableRow(attributes.map(_.dataType))) - - protected [parquet] val converters: Array[Converter] = - schema.zipWithIndex.map { - case (field, idx) => CatalystConverter.createConverter(field, idx, this) - }.toArray - - override val size = schema.size - - override val index = 0 - - override val parent = null - - // Should be only called in root group converter! 
- override def getCurrentRecord: InternalRow = current - - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) - - // for child converters to update upstream values - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - throw new UnsupportedOperationException // child converters should use the - // specific update methods below - } - - override protected[parquet] def clearBuffer(): Unit = {} - - override def start(): Unit = { - var i = 0 - while (i < size) { - current.setNullAt(i) - i = i + 1 - } - } - - override def end(): Unit = {} - - // Overridden here to avoid auto-boxing for primitive types - override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = - current.setBoolean(fieldIndex, value) - - override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = - current.setInt(fieldIndex, value) - - override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - current.setInt(fieldIndex, value) - - override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = - current.setLong(fieldIndex, value) - - override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = - current.setShort(fieldIndex, value) - - override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = - current.setByte(fieldIndex, value) - - override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = - current.setDouble(fieldIndex, value) - - override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = - current.setFloat(fieldIndex, value) - - override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = - current.update(fieldIndex, value.getBytes) - - override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - current.update(fieldIndex, UTF8String.fromBytes(value)) - - override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - current.setLong(fieldIndex, readTimestamp(value)) - - override protected[parquet] def updateDecimal( - fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { - var decimal = current(fieldIndex).asInstanceOf[Decimal] - if (decimal == null) { - decimal = new Decimal - current(fieldIndex) = decimal - } - readDecimal(decimal, value, ctype) - } -} - -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveConverter( - parent: CatalystConverter, - fieldIndex: Int) extends PrimitiveConverter { - override def addBinary(value: Binary): Unit = - parent.updateBinary(fieldIndex, value) - - override def addBoolean(value: Boolean): Unit = - parent.updateBoolean(fieldIndex, value) - - override def addDouble(value: Double): Unit = - parent.updateDouble(fieldIndex, value) - - override def addFloat(value: Float): Unit = - parent.updateFloat(fieldIndex, value) - - override def addInt(value: Int): Unit = - parent.updateInt(fieldIndex, value) - - override def addLong(value: Long): Unit = - parent.updateLong(fieldIndex, value) -} - -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet Binary to Catalyst String. - * Supports dictionaries to reduce Binary to String conversion overhead. - * - * Follows pattern in Parquet of using dictionaries, where supported, for String conversion. 
- * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) - extends CatalystPrimitiveConverter(parent, fieldIndex) { - - private[this] var dict: Array[Array[Byte]] = null - - override def hasDictionarySupport: Boolean = true - - override def setDictionary(dictionary: Dictionary): Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } - - override def addValueFromDictionary(dictionaryId: Int): Unit = - parent.updateString(fieldIndex, dict(dictionaryId)) - - override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value.getBytes) -} - -private[parquet] object CatalystArrayConverter { - val INITIAL_ARRAY_SIZE = 20 -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. - * - * @param elementType The type of the array elements (complex or primitive) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param buffer A data buffer - */ -private[parquet] class CatalystArrayConverter( - val elementType: DataType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var buffer: Buffer[Any]) - extends CatalystConverter { - - def this(elementType: DataType, index: Int, parent: CatalystConverter) = - this( - elementType, - index, - parent, - new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - protected[parquet] val converter: Converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - // fieldIndex is ignored (assumed to be zero but not checked) - if (value == null) { - throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") - } - buffer += value - } - - override protected[parquet] def clearBuffer(): Unit = { - buffer.clear() - } - - override def start(): Unit = { - if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer() - } - } - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField(index, buffer.toArray.toSeq) - clearBuffer() - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. 
- * - * @param elementType The type of the array elements (native) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param capacity The (initial) capacity of the buffer - */ -private[parquet] class CatalystNativeArrayConverter( - val elementType: AtomicType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) - extends CatalystConverter { - - type NativeType = elementType.InternalType - - private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) - - private var elements: Int = 0 - - protected[parquet] val converter: Converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = - throw new UnsupportedOperationException - - // Overridden here to avoid auto-boxing for primitive types - override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = { - checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = { - checkGrowBuffer() - buffer(elements) = value.getBytes.asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { - checkGrowBuffer() - buffer(elements) = UTF8String.fromBytes(value).asInstanceOf[NativeType] - elements += 1 - } - - override protected[parquet] def clearBuffer(): Unit = { - elements = 0 - } - - override def start(): Unit = {} - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField( - index, - buffer.slice(0, elements).toSeq) - clearBuffer() - } - - private def checkGrowBuffer(): Unit = { - if (elements >= capacity) { - val newCapacity = 2 * capacity - val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity) - Array.copy(buffer, 0, tmp, 0, capacity) - buffer = tmp - capacity = newCapacity - } - } -} - -/** - * A 
`parquet.io.api.GroupConverter` that converts a single-element groups that - * match the characteristics of an array contains null (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.ArrayType]]. - * - * @param elementType The type of the array elements (complex or primitive) - * @param index The position of this (array) field inside its parent converter - * @param parent The parent converter - * @param buffer A data buffer - */ -private[parquet] class CatalystArrayContainsNullConverter( - val elementType: DataType, - val index: Int, - protected[parquet] val parent: CatalystConverter, - protected[parquet] var buffer: Buffer[Any]) - extends CatalystConverter { - - def this(elementType: DataType, index: Int, parent: CatalystConverter) = - this( - elementType, - index, - parent, - new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) - - protected[parquet] val converter: Converter = new CatalystConverter { - - private var current: Any = null - - val converter = CatalystConverter.createConverter( - new CatalystConverter.FieldType( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - elementType, - false), - fieldIndex = 0, - parent = this) - - override def getConverter(fieldIndex: Int): Converter = converter - - override def end(): Unit = parent.updateField(index, current) - - override def start(): Unit = { - current = null - } - - override protected[parquet] val size: Int = 1 - override protected[parquet] val index: Int = 0 - override protected[parquet] val parent = CatalystArrayContainsNullConverter.this - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - current = value - } - - override protected[parquet] def clearBuffer(): Unit = {} - } - - override def getConverter(fieldIndex: Int): Converter = converter - - // arrays have only one (repeated) field, which is its elements - override val size = 1 - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - buffer += value - } - - override protected[parquet] def clearBuffer(): Unit = { - buffer.clear() - } - - override def start(): Unit = {} - - override def end(): Unit = { - assert(parent != null) - // here we need to make sure to use ArrayScalaType - parent.updateField(index, buffer.toArray.toSeq) - clearBuffer() - } -} - -/** - * This converter is for multi-element groups of primitive or complex types - * that have repetition level optional or required (so struct fields). - * - * @param schema The corresponding Catalyst schema in the form of a list of - * attributes. - * @param index - * @param parent - */ -private[parquet] class CatalystStructConverter( - override protected[parquet] val schema: Array[FieldType], - override protected[parquet] val index: Int, - override protected[parquet] val parent: CatalystConverter) - extends CatalystGroupConverter(schema, index, parent) { - - override protected[parquet] def clearBuffer(): Unit = {} - - // TODO: think about reusing the buffer - override def end(): Unit = { - assert(!isRootConverter) - // here we need to make sure to use StructScalaType - // Note: we need to actually make a copy of the array since we - // may be in a nested field - parent.updateField(index, new GenericInternalRow(current.toArray)) - } -} - -/** - * A `parquet.io.api.GroupConverter` that converts two-element groups that - * match the characteristics of a map (see - * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.types.MapType]]. 
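The map converter described here expects the legacy two-field key/value layout that `ParquetTypesConverter` writes for `MapType` columns. A sketch of that physical schema, parsed with parquet's `MessageTypeParser` (the column name and exact field names are illustrative):

```scala
import org.apache.parquet.schema.MessageTypeParser

// Legacy-style MAP layout: a repeated two-field group whose first field is the
// key and whose second field is the value. Column name is made up for the example.
val mapSchema = MessageTypeParser.parseMessageType(
  """message root {
    |  optional group string_to_int_column (MAP) {
    |    repeated group map (MAP_KEY_VALUE) {
    |      required binary key (UTF8);
    |      required int32 value;
    |    }
    |  }
    |}""".stripMargin)

println(mapSchema)
```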
- * - * @param schema - * @param index - * @param parent - */ -private[parquet] class CatalystMapConverter( - protected[parquet] val schema: Array[FieldType], - override protected[parquet] val index: Int, - override protected[parquet] val parent: CatalystConverter) - extends CatalystConverter { - - private val map = new HashMap[Any, Any]() - - private val keyValueConverter = new CatalystConverter { - private var currentKey: Any = null - private var currentValue: Any = null - val keyConverter = CatalystConverter.createConverter(schema(0), 0, this) - val valueConverter = CatalystConverter.createConverter(schema(1), 1, this) - - override def getConverter(fieldIndex: Int): Converter = { - if (fieldIndex == 0) keyConverter else valueConverter - } - - override def end(): Unit = CatalystMapConverter.this.map += currentKey -> currentValue - - override def start(): Unit = { - currentKey = null - currentValue = null - } - - override protected[parquet] val size: Int = 2 - override protected[parquet] val index: Int = 0 - override protected[parquet] val parent: CatalystConverter = CatalystMapConverter.this - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { - fieldIndex match { - case 0 => - currentKey = value - case 1 => - currentValue = value - case _ => - new RuntimePermission(s"trying to update Map with fieldIndex $fieldIndex") - } - } - - override protected[parquet] def clearBuffer(): Unit = {} - } - - override protected[parquet] val size: Int = 1 - - override protected[parquet] def clearBuffer(): Unit = {} - - override def start(): Unit = { - map.clear() - } - - override def end(): Unit = { - // here we need to make sure to use MapScalaType - parent.updateField(index, map.toMap) - } - - override def getConverter(fieldIndex: Int): Converter = keyValueConverter - - override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = - throw new UnsupportedOperationException } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 8402cd756140d..e8851ddb68026 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql.parquet -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} +import java.util import java.util.{HashMap => JHashMap} +import scala.collection.JavaConversions._ + import org.apache.hadoop.conf.Configuration import org.apache.parquet.column.ParquetProperties import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.api.ReadSupport.ReadContext -import org.apache.parquet.hadoop.api.{ReadSupport, WriteSupport} +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport, WriteSupport} import org.apache.parquet.io.api._ import org.apache.parquet.schema.MessageType @@ -36,87 +39,133 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** - * A `parquet.io.api.RecordMaterializer` for Rows. + * A [[RecordMaterializer]] for Catalyst rows. * - *@param root The root group converter for the record. 
+ * @param parquetSchema Parquet schema of the records to be read + * @param catalystSchema Catalyst schema of the rows to be constructed */ -private[parquet] class RowRecordMaterializer(root: CatalystConverter) +private[parquet] class RowRecordMaterializer(parquetSchema: MessageType, catalystSchema: StructType) extends RecordMaterializer[InternalRow] { - def this(parquetSchema: MessageType, attributes: Seq[Attribute]) = - this(CatalystConverter.createRootConverter(parquetSchema, attributes)) + private val rootConverter = new CatalystRowConverter(parquetSchema, catalystSchema, NoopUpdater) - override def getCurrentRecord: InternalRow = root.getCurrentRecord + override def getCurrentRecord: InternalRow = rootConverter.currentRow - override def getRootConverter: GroupConverter = root.asInstanceOf[GroupConverter] + override def getRootConverter: GroupConverter = rootConverter } -/** - * A `parquet.hadoop.api.ReadSupport` for Row objects. - */ private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logging { - override def prepareForRead( conf: Configuration, - stringMap: java.util.Map[String, String], + keyValueMetaData: util.Map[String, String], fileSchema: MessageType, readContext: ReadContext): RecordMaterializer[InternalRow] = { - log.debug(s"preparing for read with Parquet file schema $fileSchema") - // Note: this very much imitates AvroParquet - val parquetSchema = readContext.getRequestedSchema - var schema: Seq[Attribute] = null - - if (readContext.getReadSupportMetadata != null) { - // first try to find the read schema inside the metadata (can result from projections) - if ( - readContext - .getReadSupportMetadata - .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) != null) { - schema = ParquetTypesConverter.convertFromString( - readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) - } else { - // if unavailable, try the schema that was read originally from the file or provided - // during the creation of the Parquet relation - if (readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { - schema = ParquetTypesConverter.convertFromString( - readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) - } + log.debug(s"Preparing for read Parquet file with message type: $fileSchema") + + val toCatalyst = new CatalystSchemaConverter(conf) + val parquetRequestedSchema = readContext.getRequestedSchema + + val catalystRequestedSchema = + Option(readContext.getReadSupportMetadata).map(_.toMap).flatMap { metadata => + metadata + // First tries to read requested schema, which may result from projections + .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + // If not available, tries to read Catalyst schema from file metadata. It's only + // available if the target file is written by Spark SQL. + .orElse(metadata.get(RowReadSupport.SPARK_METADATA_KEY)) + }.map(StructType.fromString).getOrElse { + logDebug("Catalyst schema not available, falling back to Parquet schema") + toCatalyst.convert(parquetRequestedSchema) } - } - // if both unavailable, fall back to deducing the schema from the given Parquet schema - // TODO: Why it can be null? 
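Both the old and the new read paths carry the Catalyst schema through string-valued metadata; in the new path that string is a serialized `StructType`. A small round-trip sketch using the public `DataType.fromJson` (the patch itself goes through the internal `StructType.fromString` helper):

```scala
import org.apache.spark.sql.types._

// Serialize a Catalyst schema to JSON and restore it; this is the shape of the
// data that travels through the read-support metadata map.
val schema = StructType(Seq(
  StructField("a", IntegerType),
  StructField("b", StringType)))

val restored = DataType.fromJson(schema.json).asInstanceOf[StructType]
assert(restored == schema)
```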
- if (schema == null) { - log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false, true) - } - log.debug(s"list of attributes that will be read: $schema") - new RowRecordMaterializer(parquetSchema, schema) + + logDebug(s"Catalyst schema used to read Parquet files: $catalystRequestedSchema") + new RowRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema) } - override def init( - configuration: Configuration, - keyValueMetaData: java.util.Map[String, String], - fileSchema: MessageType): ReadContext = { - var parquetSchema = fileSchema - val metadata = new JHashMap[String, String]() - val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) - - if (requestedAttributes != null) { - // If the parquet file is thrift derived, there is a good chance that - // it will have the thrift class in metadata. - val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class") - parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) - metadata.put( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(requestedAttributes)) - } + override def init(context: InitContext): ReadContext = { + val conf = context.getConfiguration + + // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst + // schema of this file from its the metadata. + val maybeRowSchema = Option(conf.get(RowWriteSupport.SPARK_ROW_SCHEMA)) + + // Optional schema of requested columns, in the form of a string serialized from a Catalyst + // `StructType` containing all requested columns. + val maybeRequestedSchema = Option(conf.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) + + // Below we construct a Parquet schema containing all requested columns. This schema tells + // Parquet which columns to read. + // + // If `maybeRequestedSchema` is defined, we assemble an equivalent Parquet schema. Otherwise, + // we have to fallback to the full file schema which contains all columns in the file. + // Obviously this may waste IO bandwidth since it may read more columns than requested. + // + // Two things to note: + // + // 1. It's possible that some requested columns don't exist in the target Parquet file. For + // example, in the case of schema merging, the globally merged schema may contain extra + // columns gathered from other Parquet files. These columns will be simply filled with nulls + // when actually reading the target Parquet file. + // + // 2. When `maybeRequestedSchema` is available, we can't simply convert the Catalyst schema to + // Parquet schema using `CatalystSchemaConverter`, because the mapping is not unique due to + // non-standard behaviors of some Parquet libraries/tools. For example, a Parquet file + // containing a single integer array field `f1` may have the following legacy 2-level + // structure: + // + // message root { + // optional group f1 (LIST) { + // required INT32 element; + // } + // } + // + // while `CatalystSchemaConverter` may generate a standard 3-level structure: + // + // message root { + // optional group f1 (LIST) { + // repeated group list { + // required INT32 element; + // } + // } + // } + // + // Apparently, we can't use the 2nd schema to read the target Parquet file as they have + // different physical structures. 
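The comment block above is the heart of this change: when a requested Catalyst schema is available, the Parquet read schema is assembled column by column from the file's own physical types, so legacy layouts such as the 2-level LIST structure keep working. A self-contained sketch of that per-column assembly, using `MessageTypeParser` to stand in for the file schema (column names are made up for the example):

```scala
import org.apache.parquet.schema.{MessageType, MessageTypeParser}

// A hypothetical file schema: f1 uses the legacy 2-level LIST layout.
val fileSchema = MessageTypeParser.parseMessageType(
  """message root {
    |  optional group f1 (LIST) {
    |    repeated int32 element;
    |  }
    |  optional binary f2 (UTF8);
    |}""".stripMargin)

// Only f1 is requested. Picking its type straight out of the file schema keeps
// the legacy 2-level layout instead of regenerating a standard 3-level one.
val requested = Seq("f1")
val readSchema = requested
  .filter(name => fileSchema.containsField(name))
  .map(name => new MessageType("root", fileSchema.getType(name)))
  .fold(new MessageType("root"))(_ union _)

println(readSchema)
```

As in the patch, folding over an empty `MessageType` also handles the case where no columns are requested at all.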
+ val parquetRequestedSchema = + maybeRequestedSchema.fold(context.getFileSchema) { schemaString => + val toParquet = new CatalystSchemaConverter(conf) + val fileSchema = context.getFileSchema.asGroupType() + val fileFieldNames = fileSchema.getFields.map(_.getName).toSet + + StructType + // Deserializes the Catalyst schema of requested columns + .fromString(schemaString) + .map { field => + if (fileFieldNames.contains(field.name)) { + // If the field exists in the target Parquet file, extracts the field type from the + // full file schema and makes a single-field Parquet schema + new MessageType("root", fileSchema.getType(field.name)) + } else { + // Otherwise, just resorts to `CatalystSchemaConverter` + toParquet.convert(StructType(Array(field))) + } + } + // Merges all single-field Parquet schemas to form a complete schema for all requested + // columns. Note that it's possible that no columns are requested at all (e.g., count + // some partition column of a partitioned Parquet table). That's why `fold` is used here + // and always fallback to an empty Parquet schema. + .fold(new MessageType("root")) { + _ union _ + } + } - val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - if (origAttributesStr != null) { - metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) - } + val metadata = + Map.empty[String, String] ++ + maybeRequestedSchema.map(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++ + maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _) - new ReadSupport.ReadContext(parquetSchema, metadata) + logInfo(s"Going to read Parquet file with these requested columns: $parquetRequestedSchema") + new ReadContext(parquetRequestedSchema, metadata) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index ce456e7fbe17e..01dd6f471bd7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -259,6 +259,10 @@ private[sql] class ParquetRelation2( broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown + val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp + val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + // Create the function to set variable Parquet confs at both driver and executor side. val initLocalJobFuncOpt = ParquetRelation2.initializeLocalJobFunc( @@ -266,7 +270,11 @@ private[sql] class ParquetRelation2( filters, dataSchema, useMetadataCache, - parquetFilterPushDown) _ + parquetFilterPushDown, + assumeBinaryIsString, + assumeInt96IsTimestamp, + followParquetFormatSpec) _ + // Create the function to set input paths at the driver side. 
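The three new flags threaded through `initializeLocalJobFunc` mirror user-facing SQL options. A usage sketch from the caller's side; the first two keys are the documented options, the third key name is assumed to mirror `PARQUET_FOLLOW_PARQUET_FORMAT_SPEC` and may differ between versions, and the input path is hypothetical:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setAppName("parquet-flags").setMaster("local[2]"))
val sqlContext = new SQLContext(sc)

// Control how Parquet BINARY and INT96 columns are interpreted, and whether the
// format-spec-compliant schema conversion is used, before reading any data.
sqlContext.setConf("spark.sql.parquet.binaryAsString", "false")
sqlContext.setConf("spark.sql.parquet.int96AsTimestamp", "true")
sqlContext.setConf("spark.sql.parquet.followParquetFormatSpec", "false")

val df = sqlContext.read.parquet("/path/to/data.parquet")  // hypothetical path
df.printSchema()
```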
val setInputPaths = ParquetRelation2.initializeDriverSideJobFunc(inputFiles) _ @@ -471,9 +479,12 @@ private[sql] object ParquetRelation2 extends Logging { filters: Array[Filter], dataSchema: StructType, useMetadataCache: Boolean, - parquetFilterPushDown: Boolean)(job: Job): Unit = { + parquetFilterPushDown: Boolean, + assumeBinaryIsString: Boolean, + assumeInt96IsTimestamp: Boolean, + followParquetFormatSpec: Boolean)(job: Job): Unit = { val conf = job.getConfiguration - conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[RowReadSupport].getName()) + conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[RowReadSupport].getName) // Try to push down filters when filter push-down is enabled. if (parquetFilterPushDown) { @@ -497,6 +508,11 @@ private[sql] object ParquetRelation2 extends Logging { // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) + + // Sets flags for Parquet schema conversion + conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) + conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) + conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec) } /** This closure sets input paths at the driver side. */ diff --git a/sql/core/src/test/README.md b/sql/core/src/test/README.md new file mode 100644 index 0000000000000..3dd9861b4896d --- /dev/null +++ b/sql/core/src/test/README.md @@ -0,0 +1,33 @@ +# Notes for Parquet compatibility tests + +The following directories and files are used for Parquet compatibility tests: + +``` +. +├── README.md # This file +├── avro +│   ├── parquet-compat.avdl # Testing Avro IDL +│   └── parquet-compat.avpr # !! NO TOUCH !! Protocol file generated from parquet-compat.avdl +├── gen-java # !! NO TOUCH !! Generated Java code +├── scripts +│   └── gen-code.sh # Script used to generate Java code for Thrift and Avro +└── thrift + └── parquet-compat.thrift # Testing Thrift schema +``` + +Generated Java code are used in the following test suites: + +- `org.apache.spark.sql.parquet.ParquetAvroCompatibilitySuite` +- `org.apache.spark.sql.parquet.ParquetThriftCompatibilitySuite` + +To avoid code generation during build time, Java code generated from testing Thrift schema and Avro IDL are also checked in. + +When updating the testing Thrift schema and Avro IDL, please run `gen-code.sh` to update all the generated Java code. + +## Prerequisites + +Please ensure `avro-tools` and `thrift` are installed. You may install these two on Mac OS X via: + +```bash +$ brew install thrift avro-tools +``` diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl new file mode 100644 index 0000000000000..24729f6143e6c --- /dev/null +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This is a test protocol for testing parquet-avro compatibility. +@namespace("org.apache.spark.sql.parquet.test.avro") +protocol CompatibilityTest { + record Nested { + array nested_ints_column; + string nested_string_column; + } + + record ParquetAvroCompat { + boolean bool_column; + int int_column; + long long_column; + float float_column; + double double_column; + bytes binary_column; + string string_column; + + union { null, boolean } maybe_bool_column; + union { null, int } maybe_int_column; + union { null, long } maybe_long_column; + union { null, float } maybe_float_column; + union { null, double } maybe_double_column; + union { null, bytes } maybe_binary_column; + union { null, string } maybe_string_column; + + array strings_column; + map string_to_int_column; + map> complex_column; + } +} diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr new file mode 100644 index 0000000000000..a83b7c990dd2e --- /dev/null +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -0,0 +1,86 @@ +{ + "protocol" : "CompatibilityTest", + "namespace" : "org.apache.spark.sql.parquet.test.avro", + "types" : [ { + "type" : "record", + "name" : "Nested", + "fields" : [ { + "name" : "nested_ints_column", + "type" : { + "type" : "array", + "items" : "int" + } + }, { + "name" : "nested_string_column", + "type" : "string" + } ] + }, { + "type" : "record", + "name" : "ParquetAvroCompat", + "fields" : [ { + "name" : "bool_column", + "type" : "boolean" + }, { + "name" : "int_column", + "type" : "int" + }, { + "name" : "long_column", + "type" : "long" + }, { + "name" : "float_column", + "type" : "float" + }, { + "name" : "double_column", + "type" : "double" + }, { + "name" : "binary_column", + "type" : "bytes" + }, { + "name" : "string_column", + "type" : "string" + }, { + "name" : "maybe_bool_column", + "type" : [ "null", "boolean" ] + }, { + "name" : "maybe_int_column", + "type" : [ "null", "int" ] + }, { + "name" : "maybe_long_column", + "type" : [ "null", "long" ] + }, { + "name" : "maybe_float_column", + "type" : [ "null", "float" ] + }, { + "name" : "maybe_double_column", + "type" : [ "null", "double" ] + }, { + "name" : "maybe_binary_column", + "type" : [ "null", "bytes" ] + }, { + "name" : "maybe_string_column", + "type" : [ "null", "string" ] + }, { + "name" : "strings_column", + "type" : { + "type" : "array", + "items" : "string" + } + }, { + "name" : "string_to_int_column", + "type" : { + "type" : "map", + "values" : "int" + } + }, { + "name" : "complex_column", + "type" : { + "type" : "map", + "values" : { + "type" : "array", + "items" : "Nested" + } + } + } ] + } ], + "messages" : { } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java new file mode 100644 index 0000000000000..daec65a5bbe57 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java @@ -0,0 +1,17 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT 
DIRECTLY + */ +package org.apache.spark.sql.parquet.test.avro; + +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public interface CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"types\":[{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + + @SuppressWarnings("all") + public interface Callback extends CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.parquet.test.avro.CompatibilityTest.PROTOCOL; + } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java new file mode 100644 index 0000000000000..051f1ee903863 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List nested_ints_column; + @Deprecated public java.lang.String nested_string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. 
If that is desired then + * one should use newBuilder(). + */ + public Nested() {} + + /** + * All-args constructor. + */ + public Nested(java.util.List nested_ints_column, java.lang.String nested_string_column) { + this.nested_ints_column = nested_ints_column; + this.nested_string_column = nested_string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return nested_ints_column; + case 1: return nested_string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: nested_ints_column = (java.util.List)value$; break; + case 1: nested_string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'nested_ints_column' field. + */ + public java.util.List getNestedIntsColumn() { + return nested_ints_column; + } + + /** + * Sets the value of the 'nested_ints_column' field. + * @param value the value to set. + */ + public void setNestedIntsColumn(java.util.List value) { + this.nested_ints_column = value; + } + + /** + * Gets the value of the 'nested_string_column' field. + */ + public java.lang.String getNestedStringColumn() { + return nested_string_column; + } + + /** + * Sets the value of the 'nested_string_column' field. + * @param value the value to set. + */ + public void setNestedStringColumn(java.lang.String value) { + this.nested_string_column = value; + } + + /** Creates a new Nested RecordBuilder */ + public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder() { + return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(); + } + + /** Creates a new Nested RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { + return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + } + + /** Creates a new Nested RecordBuilder by copying an existing Nested instance */ + public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested other) { + return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + } + + /** + * RecordBuilder for Nested instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List nested_ints_column; + private java.lang.String nested_string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { + super(other); + if (isValidValue(fields()[0], other.nested_ints_column)) { + this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested_string_column)) { + this.nested_string_column = data().deepCopy(fields()[1].schema(), other.nested_string_column); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing Nested instance */ + private Builder(org.apache.spark.sql.parquet.test.avro.Nested other) { + super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + if (isValidValue(fields()[0], other.nested_ints_column)) { + this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested_string_column)) { + this.nested_string_column = data().deepCopy(fields()[1].schema(), other.nested_string_column); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'nested_ints_column' field */ + public java.util.List getNestedIntsColumn() { + return nested_ints_column; + } + + /** Sets the value of the 'nested_ints_column' field */ + public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { + validate(fields()[0], value); + this.nested_ints_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'nested_ints_column' field has been set */ + public boolean hasNestedIntsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'nested_ints_column' field */ + public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { + nested_ints_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'nested_string_column' field */ + public java.lang.String getNestedStringColumn() { + return nested_string_column; + } + + /** Sets the value of the 'nested_string_column' field */ + public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { + validate(fields()[1], value); + this.nested_string_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'nested_string_column' field has been set */ + public boolean hasNestedStringColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'nested_string_column' field */ + public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedStringColumn() { + nested_string_column = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public Nested build() { + try { + Nested record = new Nested(); + record.nested_ints_column = fieldSetFlags()[0] ? this.nested_ints_column : (java.util.List) defaultValue(fields()[0]); + record.nested_string_column = fieldSetFlags()[1] ? 
this.nested_string_column : (java.lang.String) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java new file mode 100644 index 0000000000000..354c9d73cca31 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java @@ -0,0 +1,1001 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public boolean bool_column; + @Deprecated public int int_column; + @Deprecated public long long_column; + @Deprecated public float float_column; + @Deprecated public double double_column; + @Deprecated public java.nio.ByteBuffer binary_column; + @Deprecated public java.lang.String string_column; + @Deprecated public java.lang.Boolean maybe_bool_column; + @Deprecated public java.lang.Integer maybe_int_column; + @Deprecated public java.lang.Long maybe_long_column; + @Deprecated public java.lang.Float maybe_float_column; + @Deprecated public java.lang.Double maybe_double_column; + @Deprecated public java.nio.ByteBuffer maybe_binary_column; + @Deprecated public java.lang.String maybe_string_column; + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.Map string_to_int_column; + @Deprecated public java.util.Map> complex_column; + + /** + * Default 
constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetAvroCompat() {} + + /** + * All-args constructor. + */ + public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { + this.bool_column = bool_column; + this.int_column = int_column; + this.long_column = long_column; + this.float_column = float_column; + this.double_column = double_column; + this.binary_column = binary_column; + this.string_column = string_column; + this.maybe_bool_column = maybe_bool_column; + this.maybe_int_column = maybe_int_column; + this.maybe_long_column = maybe_long_column; + this.maybe_float_column = maybe_float_column; + this.maybe_double_column = maybe_double_column; + this.maybe_binary_column = maybe_binary_column; + this.maybe_string_column = maybe_string_column; + this.strings_column = strings_column; + this.string_to_int_column = string_to_int_column; + this.complex_column = complex_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return bool_column; + case 1: return int_column; + case 2: return long_column; + case 3: return float_column; + case 4: return double_column; + case 5: return binary_column; + case 6: return string_column; + case 7: return maybe_bool_column; + case 8: return maybe_int_column; + case 9: return maybe_long_column; + case 10: return maybe_float_column; + case 11: return maybe_double_column; + case 12: return maybe_binary_column; + case 13: return maybe_string_column; + case 14: return strings_column; + case 15: return string_to_int_column; + case 16: return complex_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. 
+ @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: bool_column = (java.lang.Boolean)value$; break; + case 1: int_column = (java.lang.Integer)value$; break; + case 2: long_column = (java.lang.Long)value$; break; + case 3: float_column = (java.lang.Float)value$; break; + case 4: double_column = (java.lang.Double)value$; break; + case 5: binary_column = (java.nio.ByteBuffer)value$; break; + case 6: string_column = (java.lang.String)value$; break; + case 7: maybe_bool_column = (java.lang.Boolean)value$; break; + case 8: maybe_int_column = (java.lang.Integer)value$; break; + case 9: maybe_long_column = (java.lang.Long)value$; break; + case 10: maybe_float_column = (java.lang.Float)value$; break; + case 11: maybe_double_column = (java.lang.Double)value$; break; + case 12: maybe_binary_column = (java.nio.ByteBuffer)value$; break; + case 13: maybe_string_column = (java.lang.String)value$; break; + case 14: strings_column = (java.util.List)value$; break; + case 15: string_to_int_column = (java.util.Map)value$; break; + case 16: complex_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'bool_column' field. + */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** + * Sets the value of the 'bool_column' field. + * @param value the value to set. + */ + public void setBoolColumn(java.lang.Boolean value) { + this.bool_column = value; + } + + /** + * Gets the value of the 'int_column' field. + */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** + * Sets the value of the 'int_column' field. + * @param value the value to set. + */ + public void setIntColumn(java.lang.Integer value) { + this.int_column = value; + } + + /** + * Gets the value of the 'long_column' field. + */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** + * Sets the value of the 'long_column' field. + * @param value the value to set. + */ + public void setLongColumn(java.lang.Long value) { + this.long_column = value; + } + + /** + * Gets the value of the 'float_column' field. + */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** + * Sets the value of the 'float_column' field. + * @param value the value to set. + */ + public void setFloatColumn(java.lang.Float value) { + this.float_column = value; + } + + /** + * Gets the value of the 'double_column' field. + */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** + * Sets the value of the 'double_column' field. + * @param value the value to set. + */ + public void setDoubleColumn(java.lang.Double value) { + this.double_column = value; + } + + /** + * Gets the value of the 'binary_column' field. + */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** + * Sets the value of the 'binary_column' field. + * @param value the value to set. + */ + public void setBinaryColumn(java.nio.ByteBuffer value) { + this.binary_column = value; + } + + /** + * Gets the value of the 'string_column' field. + */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** + * Sets the value of the 'string_column' field. + * @param value the value to set. + */ + public void setStringColumn(java.lang.String value) { + this.string_column = value; + } + + /** + * Gets the value of the 'maybe_bool_column' field. 
+ */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** + * Sets the value of the 'maybe_bool_column' field. + * @param value the value to set. + */ + public void setMaybeBoolColumn(java.lang.Boolean value) { + this.maybe_bool_column = value; + } + + /** + * Gets the value of the 'maybe_int_column' field. + */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** + * Sets the value of the 'maybe_int_column' field. + * @param value the value to set. + */ + public void setMaybeIntColumn(java.lang.Integer value) { + this.maybe_int_column = value; + } + + /** + * Gets the value of the 'maybe_long_column' field. + */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** + * Sets the value of the 'maybe_long_column' field. + * @param value the value to set. + */ + public void setMaybeLongColumn(java.lang.Long value) { + this.maybe_long_column = value; + } + + /** + * Gets the value of the 'maybe_float_column' field. + */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** + * Sets the value of the 'maybe_float_column' field. + * @param value the value to set. + */ + public void setMaybeFloatColumn(java.lang.Float value) { + this.maybe_float_column = value; + } + + /** + * Gets the value of the 'maybe_double_column' field. + */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** + * Sets the value of the 'maybe_double_column' field. + * @param value the value to set. + */ + public void setMaybeDoubleColumn(java.lang.Double value) { + this.maybe_double_column = value; + } + + /** + * Gets the value of the 'maybe_binary_column' field. + */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** + * Sets the value of the 'maybe_binary_column' field. + * @param value the value to set. + */ + public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { + this.maybe_binary_column = value; + } + + /** + * Gets the value of the 'maybe_string_column' field. + */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** + * Sets the value of the 'maybe_string_column' field. + * @param value the value to set. + */ + public void setMaybeStringColumn(java.lang.String value) { + this.maybe_string_column = value; + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'string_to_int_column' field. + */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** + * Sets the value of the 'string_to_int_column' field. + * @param value the value to set. + */ + public void setStringToIntColumn(java.util.Map value) { + this.string_to_int_column = value; + } + + /** + * Gets the value of the 'complex_column' field. + */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** + * Sets the value of the 'complex_column' field. + * @param value the value to set. 
+ */ + public void setComplexColumn(java.util.Map> value) { + this.complex_column = value; + } + + /** Creates a new ParquetAvroCompat RecordBuilder */ + public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { + return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { + return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ + public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { + return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** + * RecordBuilder for ParquetAvroCompat instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private boolean bool_column; + private int int_column; + private long long_column; + private float float_column; + private double double_column; + private java.nio.ByteBuffer binary_column; + private java.lang.String string_column; + private java.lang.Boolean maybe_bool_column; + private java.lang.Integer maybe_int_column; + private java.lang.Long maybe_long_column; + private java.lang.Float maybe_float_column; + private java.lang.Double maybe_double_column; + private java.nio.ByteBuffer maybe_binary_column; + private java.lang.String maybe_string_column; + private java.util.List strings_column; + private java.util.Map string_to_int_column; + private java.util.Map> complex_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { + super(other); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + if (isValidValue(fields()[7], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[7].schema(), 
other.maybe_bool_column); + fieldSetFlags()[7] = true; + } + if (isValidValue(fields()[8], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); + fieldSetFlags()[8] = true; + } + if (isValidValue(fields()[9], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); + fieldSetFlags()[9] = true; + } + if (isValidValue(fields()[10], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); + fieldSetFlags()[10] = true; + } + if (isValidValue(fields()[11], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); + fieldSetFlags()[11] = true; + } + if (isValidValue(fields()[12], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); + fieldSetFlags()[12] = true; + } + if (isValidValue(fields()[13], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[13].schema(), other.maybe_string_column); + fieldSetFlags()[13] = true; + } + if (isValidValue(fields()[14], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); + fieldSetFlags()[14] = true; + } + if (isValidValue(fields()[15], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); + fieldSetFlags()[15] = true; + } + if (isValidValue(fields()[16], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); + fieldSetFlags()[16] = true; + } + } + + /** Creates a Builder by copying an existing ParquetAvroCompat instance */ + private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { + super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + if (isValidValue(fields()[7], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[7].schema(), other.maybe_bool_column); + fieldSetFlags()[7] = true; + } + if (isValidValue(fields()[8], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); + fieldSetFlags()[8] = true; + } + if 
(isValidValue(fields()[9], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); + fieldSetFlags()[9] = true; + } + if (isValidValue(fields()[10], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); + fieldSetFlags()[10] = true; + } + if (isValidValue(fields()[11], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); + fieldSetFlags()[11] = true; + } + if (isValidValue(fields()[12], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); + fieldSetFlags()[12] = true; + } + if (isValidValue(fields()[13], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[13].schema(), other.maybe_string_column); + fieldSetFlags()[13] = true; + } + if (isValidValue(fields()[14], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); + fieldSetFlags()[14] = true; + } + if (isValidValue(fields()[15], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); + fieldSetFlags()[15] = true; + } + if (isValidValue(fields()[16], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); + fieldSetFlags()[16] = true; + } + } + + /** Gets the value of the 'bool_column' field */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** Sets the value of the 'bool_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { + validate(fields()[0], value); + this.bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'bool_column' field has been set */ + public boolean hasBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'bool_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'int_column' field */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** Sets the value of the 'int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { + validate(fields()[1], value); + this.int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'int_column' field has been set */ + public boolean hasIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'long_column' field */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** Sets the value of the 'long_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { + validate(fields()[2], value); + this.long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'long_column' field has been set */ + public boolean hasLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'long_column' field */ + public 
org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'float_column' field */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** Sets the value of the 'float_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { + validate(fields()[3], value); + this.float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'float_column' field has been set */ + public boolean hasFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'float_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'double_column' field */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** Sets the value of the 'double_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double value) { + validate(fields()[4], value); + this.double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'double_column' field has been set */ + public boolean hasDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'double_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'binary_column' field */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** Sets the value of the 'binary_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'binary_column' field has been set */ + public boolean hasBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'binary_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { + binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'string_column' field */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** Sets the value of the 'string_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'string_column' field has been set */ + public boolean hasStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'string_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { + string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + /** Gets the value of the 'maybe_bool_column' field */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** Sets the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { + validate(fields()[7], value); + this.maybe_bool_column = value; + fieldSetFlags()[7] = 
true; + return this; + } + + /** Checks whether the 'maybe_bool_column' field has been set */ + public boolean hasMaybeBoolColumn() { + return fieldSetFlags()[7]; + } + + /** Clears the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { + maybe_bool_column = null; + fieldSetFlags()[7] = false; + return this; + } + + /** Gets the value of the 'maybe_int_column' field */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** Sets the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { + validate(fields()[8], value); + this.maybe_int_column = value; + fieldSetFlags()[8] = true; + return this; + } + + /** Checks whether the 'maybe_int_column' field has been set */ + public boolean hasMaybeIntColumn() { + return fieldSetFlags()[8]; + } + + /** Clears the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { + maybe_int_column = null; + fieldSetFlags()[8] = false; + return this; + } + + /** Gets the value of the 'maybe_long_column' field */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** Sets the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { + validate(fields()[9], value); + this.maybe_long_column = value; + fieldSetFlags()[9] = true; + return this; + } + + /** Checks whether the 'maybe_long_column' field has been set */ + public boolean hasMaybeLongColumn() { + return fieldSetFlags()[9]; + } + + /** Clears the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { + maybe_long_column = null; + fieldSetFlags()[9] = false; + return this; + } + + /** Gets the value of the 'maybe_float_column' field */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** Sets the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { + validate(fields()[10], value); + this.maybe_float_column = value; + fieldSetFlags()[10] = true; + return this; + } + + /** Checks whether the 'maybe_float_column' field has been set */ + public boolean hasMaybeFloatColumn() { + return fieldSetFlags()[10]; + } + + /** Clears the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { + maybe_float_column = null; + fieldSetFlags()[10] = false; + return this; + } + + /** Gets the value of the 'maybe_double_column' field */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** Sets the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { + validate(fields()[11], value); + this.maybe_double_column = value; + fieldSetFlags()[11] = true; + return this; + } + + /** Checks whether the 'maybe_double_column' field has been set */ + public boolean hasMaybeDoubleColumn() { + return fieldSetFlags()[11]; + } + + /** Clears the value of the 'maybe_double_column' field */ + public 
org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { + maybe_double_column = null; + fieldSetFlags()[11] = false; + return this; + } + + /** Gets the value of the 'maybe_binary_column' field */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** Sets the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[12], value); + this.maybe_binary_column = value; + fieldSetFlags()[12] = true; + return this; + } + + /** Checks whether the 'maybe_binary_column' field has been set */ + public boolean hasMaybeBinaryColumn() { + return fieldSetFlags()[12]; + } + + /** Clears the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { + maybe_binary_column = null; + fieldSetFlags()[12] = false; + return this; + } + + /** Gets the value of the 'maybe_string_column' field */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** Sets the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { + validate(fields()[13], value); + this.maybe_string_column = value; + fieldSetFlags()[13] = true; + return this; + } + + /** Checks whether the 'maybe_string_column' field has been set */ + public boolean hasMaybeStringColumn() { + return fieldSetFlags()[13]; + } + + /** Clears the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { + maybe_string_column = null; + fieldSetFlags()[13] = false; + return this; + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { + validate(fields()[14], value); + this.strings_column = value; + fieldSetFlags()[14] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[14]; + } + + /** Clears the value of the 'strings_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[14] = false; + return this; + } + + /** Gets the value of the 'string_to_int_column' field */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** Sets the value of the 'string_to_int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { + validate(fields()[15], value); + this.string_to_int_column = value; + fieldSetFlags()[15] = true; + return this; + } + + /** Checks whether the 'string_to_int_column' field has been set */ + public boolean hasStringToIntColumn() { + return fieldSetFlags()[15]; + } + + /** Clears the value of the 'string_to_int_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { + string_to_int_column = null; + fieldSetFlags()[15] = false; + return this; + } + + /** Gets the value of the 'complex_column' field */ + public java.util.Map> 
getComplexColumn() { + return complex_column; + } + + /** Sets the value of the 'complex_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { + validate(fields()[16], value); + this.complex_column = value; + fieldSetFlags()[16] = true; + return this; + } + + /** Checks whether the 'complex_column' field has been set */ + public boolean hasComplexColumn() { + return fieldSetFlags()[16]; + } + + /** Clears the value of the 'complex_column' field */ + public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { + complex_column = null; + fieldSetFlags()[16] = false; + return this; + } + + @Override + public ParquetAvroCompat build() { + try { + ParquetAvroCompat record = new ParquetAvroCompat(); + record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); + record.float_column = fieldSetFlags()[3] ? this.float_column : (java.lang.Float) defaultValue(fields()[3]); + record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); + record.binary_column = fieldSetFlags()[5] ? this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); + record.maybe_bool_column = fieldSetFlags()[7] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[7]); + record.maybe_int_column = fieldSetFlags()[8] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[8]); + record.maybe_long_column = fieldSetFlags()[9] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[9]); + record.maybe_float_column = fieldSetFlags()[10] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[10]); + record.maybe_double_column = fieldSetFlags()[11] ? this.maybe_double_column : (java.lang.Double) defaultValue(fields()[11]); + record.maybe_binary_column = fieldSetFlags()[12] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[12]); + record.maybe_string_column = fieldSetFlags()[13] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[13]); + record.strings_column = fieldSetFlags()[14] ? this.strings_column : (java.util.List) defaultValue(fields()[14]); + record.string_to_int_column = fieldSetFlags()[15] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[15]); + record.complex_column = fieldSetFlags()[16] ? 
this.complex_column : (java.util.Map>) defaultValue(fields()[16]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Nested.java new file mode 100644 index 0000000000000..281e60cc3ae34 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Nested.java @@ -0,0 +1,541 @@ +/** + * Autogenerated by Thrift Compiler (0.9.2) + * + * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + * @generated + */ +package org.apache.spark.sql.parquet.test.thrift; + +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.protocol.TProtocolException; +import org.apache.thrift.EncodingUtils; +import org.apache.thrift.TException; +import org.apache.thrift.async.AsyncMethodCallback; +import org.apache.thrift.server.AbstractNonblockingServer.*; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.Set; +import java.util.HashSet; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; +import java.nio.ByteBuffer; +import java.util.Arrays; +import javax.annotation.Generated; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings({"cast", "rawtypes", "serial", "unchecked"}) +@Generated(value = "Autogenerated by Thrift Compiler (0.9.2)", date = "2015-7-7") +public class Nested implements org.apache.thrift.TBase, java.io.Serializable, Cloneable, Comparable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Nested"); + + private static final org.apache.thrift.protocol.TField NESTED_INTS_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("nestedIntsColumn", org.apache.thrift.protocol.TType.LIST, (short)1); + private static final org.apache.thrift.protocol.TField NESTED_STRING_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("nestedStringColumn", org.apache.thrift.protocol.TType.STRING, (short)2); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new NestedStandardSchemeFactory()); + schemes.put(TupleScheme.class, new NestedTupleSchemeFactory()); + } + + public List nestedIntsColumn; // required + public String nestedStringColumn; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + NESTED_INTS_COLUMN((short)1, "nestedIntsColumn"), + NESTED_STRING_COLUMN((short)2, "nestedStringColumn"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. 
+ */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // NESTED_INTS_COLUMN + return NESTED_INTS_COLUMN; + case 2: // NESTED_STRING_COLUMN + return NESTED_STRING_COLUMN; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. + */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.NESTED_INTS_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("nestedIntsColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)))); + tmpMap.put(_Fields.NESTED_STRING_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("nestedStringColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(Nested.class, metaDataMap); + } + + public Nested() { + } + + public Nested( + List nestedIntsColumn, + String nestedStringColumn) + { + this(); + this.nestedIntsColumn = nestedIntsColumn; + this.nestedStringColumn = nestedStringColumn; + } + + /** + * Performs a deep copy on other. + */ + public Nested(Nested other) { + if (other.isSetNestedIntsColumn()) { + List __this__nestedIntsColumn = new ArrayList(other.nestedIntsColumn); + this.nestedIntsColumn = __this__nestedIntsColumn; + } + if (other.isSetNestedStringColumn()) { + this.nestedStringColumn = other.nestedStringColumn; + } + } + + public Nested deepCopy() { + return new Nested(this); + } + + @Override + public void clear() { + this.nestedIntsColumn = null; + this.nestedStringColumn = null; + } + + public int getNestedIntsColumnSize() { + return (this.nestedIntsColumn == null) ? 0 : this.nestedIntsColumn.size(); + } + + public java.util.Iterator getNestedIntsColumnIterator() { + return (this.nestedIntsColumn == null) ? 
null : this.nestedIntsColumn.iterator(); + } + + public void addToNestedIntsColumn(int elem) { + if (this.nestedIntsColumn == null) { + this.nestedIntsColumn = new ArrayList(); + } + this.nestedIntsColumn.add(elem); + } + + public List getNestedIntsColumn() { + return this.nestedIntsColumn; + } + + public Nested setNestedIntsColumn(List nestedIntsColumn) { + this.nestedIntsColumn = nestedIntsColumn; + return this; + } + + public void unsetNestedIntsColumn() { + this.nestedIntsColumn = null; + } + + /** Returns true if field nestedIntsColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetNestedIntsColumn() { + return this.nestedIntsColumn != null; + } + + public void setNestedIntsColumnIsSet(boolean value) { + if (!value) { + this.nestedIntsColumn = null; + } + } + + public String getNestedStringColumn() { + return this.nestedStringColumn; + } + + public Nested setNestedStringColumn(String nestedStringColumn) { + this.nestedStringColumn = nestedStringColumn; + return this; + } + + public void unsetNestedStringColumn() { + this.nestedStringColumn = null; + } + + /** Returns true if field nestedStringColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetNestedStringColumn() { + return this.nestedStringColumn != null; + } + + public void setNestedStringColumnIsSet(boolean value) { + if (!value) { + this.nestedStringColumn = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case NESTED_INTS_COLUMN: + if (value == null) { + unsetNestedIntsColumn(); + } else { + setNestedIntsColumn((List)value); + } + break; + + case NESTED_STRING_COLUMN: + if (value == null) { + unsetNestedStringColumn(); + } else { + setNestedStringColumn((String)value); + } + break; + + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case NESTED_INTS_COLUMN: + return getNestedIntsColumn(); + + case NESTED_STRING_COLUMN: + return getNestedStringColumn(); + + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case NESTED_INTS_COLUMN: + return isSetNestedIntsColumn(); + case NESTED_STRING_COLUMN: + return isSetNestedStringColumn(); + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) + return false; + if (that instanceof Nested) + return this.equals((Nested)that); + return false; + } + + public boolean equals(Nested that) { + if (that == null) + return false; + + boolean this_present_nestedIntsColumn = true && this.isSetNestedIntsColumn(); + boolean that_present_nestedIntsColumn = true && that.isSetNestedIntsColumn(); + if (this_present_nestedIntsColumn || that_present_nestedIntsColumn) { + if (!(this_present_nestedIntsColumn && that_present_nestedIntsColumn)) + return false; + if (!this.nestedIntsColumn.equals(that.nestedIntsColumn)) + return false; + } + + boolean this_present_nestedStringColumn = true && this.isSetNestedStringColumn(); + boolean that_present_nestedStringColumn = true && that.isSetNestedStringColumn(); + if (this_present_nestedStringColumn || that_present_nestedStringColumn) { + if (!(this_present_nestedStringColumn && that_present_nestedStringColumn)) + return false; + if (!this.nestedStringColumn.equals(that.nestedStringColumn)) + return false; + } + + return true; + } + + 
@Override + public int hashCode() { + List list = new ArrayList(); + + boolean present_nestedIntsColumn = true && (isSetNestedIntsColumn()); + list.add(present_nestedIntsColumn); + if (present_nestedIntsColumn) + list.add(nestedIntsColumn); + + boolean present_nestedStringColumn = true && (isSetNestedStringColumn()); + list.add(present_nestedStringColumn); + if (present_nestedStringColumn) + list.add(nestedStringColumn); + + return list.hashCode(); + } + + @Override + public int compareTo(Nested other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + + lastComparison = Boolean.valueOf(isSetNestedIntsColumn()).compareTo(other.isSetNestedIntsColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetNestedIntsColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.nestedIntsColumn, other.nestedIntsColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetNestedStringColumn()).compareTo(other.isSetNestedStringColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetNestedStringColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.nestedStringColumn, other.nestedStringColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("Nested("); + boolean first = true; + + sb.append("nestedIntsColumn:"); + if (this.nestedIntsColumn == null) { + sb.append("null"); + } else { + sb.append(this.nestedIntsColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("nestedStringColumn:"); + if (this.nestedStringColumn == null) { + sb.append("null"); + } else { + sb.append(this.nestedStringColumn); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + if (nestedIntsColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'nestedIntsColumn' was not present! Struct: " + toString()); + } + if (nestedStringColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'nestedStringColumn' was not present! 
Struct: " + toString()); + } + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class NestedStandardSchemeFactory implements SchemeFactory { + public NestedStandardScheme getScheme() { + return new NestedStandardScheme(); + } + } + + private static class NestedStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, Nested struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // NESTED_INTS_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list0 = iprot.readListBegin(); + struct.nestedIntsColumn = new ArrayList(_list0.size); + int _elem1; + for (int _i2 = 0; _i2 < _list0.size; ++_i2) + { + _elem1 = iprot.readI32(); + struct.nestedIntsColumn.add(_elem1); + } + iprot.readListEnd(); + } + struct.setNestedIntsColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 2: // NESTED_STRING_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.nestedStringColumn = iprot.readString(); + struct.setNestedStringColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + + // check for required fields of primitive type, which can't be checked in the validate method + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, Nested struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + if (struct.nestedIntsColumn != null) { + oprot.writeFieldBegin(NESTED_INTS_COLUMN_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, struct.nestedIntsColumn.size())); + for (int _iter3 : struct.nestedIntsColumn) + { + oprot.writeI32(_iter3); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.nestedStringColumn != null) { + oprot.writeFieldBegin(NESTED_STRING_COLUMN_FIELD_DESC); + oprot.writeString(struct.nestedStringColumn); + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class NestedTupleSchemeFactory implements SchemeFactory { + public NestedTupleScheme getScheme() { + return new NestedTupleScheme(); + } + } + + private static class NestedTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, Nested struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = 
(TTupleProtocol) prot; + { + oprot.writeI32(struct.nestedIntsColumn.size()); + for (int _iter4 : struct.nestedIntsColumn) + { + oprot.writeI32(_iter4); + } + } + oprot.writeString(struct.nestedStringColumn); + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, Nested struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + { + org.apache.thrift.protocol.TList _list5 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, iprot.readI32()); + struct.nestedIntsColumn = new ArrayList(_list5.size); + int _elem6; + for (int _i7 = 0; _i7 < _list5.size; ++_i7) + { + _elem6 = iprot.readI32(); + struct.nestedIntsColumn.add(_elem6); + } + } + struct.setNestedIntsColumnIsSet(true); + struct.nestedStringColumn = iprot.readString(); + struct.setNestedStringColumnIsSet(true); + } + } + +} + diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/ParquetThriftCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/ParquetThriftCompat.java new file mode 100644 index 0000000000000..326ae9dbaa0d1 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/ParquetThriftCompat.java @@ -0,0 +1,2808 @@ +/** + * Autogenerated by Thrift Compiler (0.9.2) + * + * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + * @generated + */ +package org.apache.spark.sql.parquet.test.thrift; + +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.protocol.TProtocolException; +import org.apache.thrift.EncodingUtils; +import org.apache.thrift.TException; +import org.apache.thrift.async.AsyncMethodCallback; +import org.apache.thrift.server.AbstractNonblockingServer.*; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.Set; +import java.util.HashSet; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; +import java.nio.ByteBuffer; +import java.util.Arrays; +import javax.annotation.Generated; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings({"cast", "rawtypes", "serial", "unchecked"}) +/** + * This is a test struct for testing parquet-thrift compatibility. 
+ */ +@Generated(value = "Autogenerated by Thrift Compiler (0.9.2)", date = "2015-7-7") +public class ParquetThriftCompat implements org.apache.thrift.TBase, java.io.Serializable, Cloneable, Comparable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("ParquetThriftCompat"); + + private static final org.apache.thrift.protocol.TField BOOL_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("boolColumn", org.apache.thrift.protocol.TType.BOOL, (short)1); + private static final org.apache.thrift.protocol.TField BYTE_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("byteColumn", org.apache.thrift.protocol.TType.BYTE, (short)2); + private static final org.apache.thrift.protocol.TField SHORT_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("shortColumn", org.apache.thrift.protocol.TType.I16, (short)3); + private static final org.apache.thrift.protocol.TField INT_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("intColumn", org.apache.thrift.protocol.TType.I32, (short)4); + private static final org.apache.thrift.protocol.TField LONG_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("longColumn", org.apache.thrift.protocol.TType.I64, (short)5); + private static final org.apache.thrift.protocol.TField DOUBLE_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("doubleColumn", org.apache.thrift.protocol.TType.DOUBLE, (short)6); + private static final org.apache.thrift.protocol.TField BINARY_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("binaryColumn", org.apache.thrift.protocol.TType.STRING, (short)7); + private static final org.apache.thrift.protocol.TField STRING_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("stringColumn", org.apache.thrift.protocol.TType.STRING, (short)8); + private static final org.apache.thrift.protocol.TField ENUM_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("enumColumn", org.apache.thrift.protocol.TType.I32, (short)9); + private static final org.apache.thrift.protocol.TField MAYBE_BOOL_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeBoolColumn", org.apache.thrift.protocol.TType.BOOL, (short)10); + private static final org.apache.thrift.protocol.TField MAYBE_BYTE_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeByteColumn", org.apache.thrift.protocol.TType.BYTE, (short)11); + private static final org.apache.thrift.protocol.TField MAYBE_SHORT_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeShortColumn", org.apache.thrift.protocol.TType.I16, (short)12); + private static final org.apache.thrift.protocol.TField MAYBE_INT_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeIntColumn", org.apache.thrift.protocol.TType.I32, (short)13); + private static final org.apache.thrift.protocol.TField MAYBE_LONG_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeLongColumn", org.apache.thrift.protocol.TType.I64, (short)14); + private static final org.apache.thrift.protocol.TField MAYBE_DOUBLE_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeDoubleColumn", org.apache.thrift.protocol.TType.DOUBLE, (short)15); + private static final org.apache.thrift.protocol.TField MAYBE_BINARY_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeBinaryColumn", org.apache.thrift.protocol.TType.STRING, (short)16); + private static final org.apache.thrift.protocol.TField MAYBE_STRING_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeStringColumn", 
org.apache.thrift.protocol.TType.STRING, (short)17); + private static final org.apache.thrift.protocol.TField MAYBE_ENUM_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("maybeEnumColumn", org.apache.thrift.protocol.TType.I32, (short)18); + private static final org.apache.thrift.protocol.TField STRINGS_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("stringsColumn", org.apache.thrift.protocol.TType.LIST, (short)19); + private static final org.apache.thrift.protocol.TField INT_SET_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("intSetColumn", org.apache.thrift.protocol.TType.SET, (short)20); + private static final org.apache.thrift.protocol.TField INT_TO_STRING_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("intToStringColumn", org.apache.thrift.protocol.TType.MAP, (short)21); + private static final org.apache.thrift.protocol.TField COMPLEX_COLUMN_FIELD_DESC = new org.apache.thrift.protocol.TField("complexColumn", org.apache.thrift.protocol.TType.MAP, (short)22); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new ParquetThriftCompatStandardSchemeFactory()); + schemes.put(TupleScheme.class, new ParquetThriftCompatTupleSchemeFactory()); + } + + public boolean boolColumn; // required + public byte byteColumn; // required + public short shortColumn; // required + public int intColumn; // required + public long longColumn; // required + public double doubleColumn; // required + public ByteBuffer binaryColumn; // required + public String stringColumn; // required + /** + * + * @see Suit + */ + public Suit enumColumn; // required + public boolean maybeBoolColumn; // optional + public byte maybeByteColumn; // optional + public short maybeShortColumn; // optional + public int maybeIntColumn; // optional + public long maybeLongColumn; // optional + public double maybeDoubleColumn; // optional + public ByteBuffer maybeBinaryColumn; // optional + public String maybeStringColumn; // optional + /** + * + * @see Suit + */ + public Suit maybeEnumColumn; // optional + public List stringsColumn; // required + public Set intSetColumn; // required + public Map intToStringColumn; // required + public Map> complexColumn; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. 
*/ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + BOOL_COLUMN((short)1, "boolColumn"), + BYTE_COLUMN((short)2, "byteColumn"), + SHORT_COLUMN((short)3, "shortColumn"), + INT_COLUMN((short)4, "intColumn"), + LONG_COLUMN((short)5, "longColumn"), + DOUBLE_COLUMN((short)6, "doubleColumn"), + BINARY_COLUMN((short)7, "binaryColumn"), + STRING_COLUMN((short)8, "stringColumn"), + /** + * + * @see Suit + */ + ENUM_COLUMN((short)9, "enumColumn"), + MAYBE_BOOL_COLUMN((short)10, "maybeBoolColumn"), + MAYBE_BYTE_COLUMN((short)11, "maybeByteColumn"), + MAYBE_SHORT_COLUMN((short)12, "maybeShortColumn"), + MAYBE_INT_COLUMN((short)13, "maybeIntColumn"), + MAYBE_LONG_COLUMN((short)14, "maybeLongColumn"), + MAYBE_DOUBLE_COLUMN((short)15, "maybeDoubleColumn"), + MAYBE_BINARY_COLUMN((short)16, "maybeBinaryColumn"), + MAYBE_STRING_COLUMN((short)17, "maybeStringColumn"), + /** + * + * @see Suit + */ + MAYBE_ENUM_COLUMN((short)18, "maybeEnumColumn"), + STRINGS_COLUMN((short)19, "stringsColumn"), + INT_SET_COLUMN((short)20, "intSetColumn"), + INT_TO_STRING_COLUMN((short)21, "intToStringColumn"), + COMPLEX_COLUMN((short)22, "complexColumn"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. + */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // BOOL_COLUMN + return BOOL_COLUMN; + case 2: // BYTE_COLUMN + return BYTE_COLUMN; + case 3: // SHORT_COLUMN + return SHORT_COLUMN; + case 4: // INT_COLUMN + return INT_COLUMN; + case 5: // LONG_COLUMN + return LONG_COLUMN; + case 6: // DOUBLE_COLUMN + return DOUBLE_COLUMN; + case 7: // BINARY_COLUMN + return BINARY_COLUMN; + case 8: // STRING_COLUMN + return STRING_COLUMN; + case 9: // ENUM_COLUMN + return ENUM_COLUMN; + case 10: // MAYBE_BOOL_COLUMN + return MAYBE_BOOL_COLUMN; + case 11: // MAYBE_BYTE_COLUMN + return MAYBE_BYTE_COLUMN; + case 12: // MAYBE_SHORT_COLUMN + return MAYBE_SHORT_COLUMN; + case 13: // MAYBE_INT_COLUMN + return MAYBE_INT_COLUMN; + case 14: // MAYBE_LONG_COLUMN + return MAYBE_LONG_COLUMN; + case 15: // MAYBE_DOUBLE_COLUMN + return MAYBE_DOUBLE_COLUMN; + case 16: // MAYBE_BINARY_COLUMN + return MAYBE_BINARY_COLUMN; + case 17: // MAYBE_STRING_COLUMN + return MAYBE_STRING_COLUMN; + case 18: // MAYBE_ENUM_COLUMN + return MAYBE_ENUM_COLUMN; + case 19: // STRINGS_COLUMN + return STRINGS_COLUMN; + case 20: // INT_SET_COLUMN + return INT_SET_COLUMN; + case 21: // INT_TO_STRING_COLUMN + return INT_TO_STRING_COLUMN; + case 22: // COMPLEX_COLUMN + return COMPLEX_COLUMN; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. 
+ */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + private static final int __BOOLCOLUMN_ISSET_ID = 0; + private static final int __BYTECOLUMN_ISSET_ID = 1; + private static final int __SHORTCOLUMN_ISSET_ID = 2; + private static final int __INTCOLUMN_ISSET_ID = 3; + private static final int __LONGCOLUMN_ISSET_ID = 4; + private static final int __DOUBLECOLUMN_ISSET_ID = 5; + private static final int __MAYBEBOOLCOLUMN_ISSET_ID = 6; + private static final int __MAYBEBYTECOLUMN_ISSET_ID = 7; + private static final int __MAYBESHORTCOLUMN_ISSET_ID = 8; + private static final int __MAYBEINTCOLUMN_ISSET_ID = 9; + private static final int __MAYBELONGCOLUMN_ISSET_ID = 10; + private static final int __MAYBEDOUBLECOLUMN_ISSET_ID = 11; + private short __isset_bitfield = 0; + private static final _Fields optionals[] = {_Fields.MAYBE_BOOL_COLUMN,_Fields.MAYBE_BYTE_COLUMN,_Fields.MAYBE_SHORT_COLUMN,_Fields.MAYBE_INT_COLUMN,_Fields.MAYBE_LONG_COLUMN,_Fields.MAYBE_DOUBLE_COLUMN,_Fields.MAYBE_BINARY_COLUMN,_Fields.MAYBE_STRING_COLUMN,_Fields.MAYBE_ENUM_COLUMN}; + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.BOOL_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("boolColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.BOOL))); + tmpMap.put(_Fields.BYTE_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("byteColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.BYTE))); + tmpMap.put(_Fields.SHORT_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("shortColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I16))); + tmpMap.put(_Fields.INT_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("intColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); + tmpMap.put(_Fields.LONG_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("longColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I64))); + tmpMap.put(_Fields.DOUBLE_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("doubleColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.DOUBLE))); + tmpMap.put(_Fields.BINARY_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("binaryColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING , true))); + tmpMap.put(_Fields.STRING_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("stringColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new 
org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + tmpMap.put(_Fields.ENUM_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("enumColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.EnumMetaData(org.apache.thrift.protocol.TType.ENUM, Suit.class))); + tmpMap.put(_Fields.MAYBE_BOOL_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeBoolColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.BOOL))); + tmpMap.put(_Fields.MAYBE_BYTE_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeByteColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.BYTE))); + tmpMap.put(_Fields.MAYBE_SHORT_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeShortColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I16))); + tmpMap.put(_Fields.MAYBE_INT_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeIntColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); + tmpMap.put(_Fields.MAYBE_LONG_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeLongColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I64))); + tmpMap.put(_Fields.MAYBE_DOUBLE_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeDoubleColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.DOUBLE))); + tmpMap.put(_Fields.MAYBE_BINARY_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeBinaryColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING , true))); + tmpMap.put(_Fields.MAYBE_STRING_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeStringColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + tmpMap.put(_Fields.MAYBE_ENUM_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("maybeEnumColumn", org.apache.thrift.TFieldRequirementType.OPTIONAL, + new org.apache.thrift.meta_data.EnumMetaData(org.apache.thrift.protocol.TType.ENUM, Suit.class))); + tmpMap.put(_Fields.STRINGS_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("stringsColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + tmpMap.put(_Fields.INT_SET_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("intSetColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.SetMetaData(org.apache.thrift.protocol.TType.SET, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)))); + tmpMap.put(_Fields.INT_TO_STRING_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("intToStringColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.MapMetaData(org.apache.thrift.protocol.TType.MAP, + new 
org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32), + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + tmpMap.put(_Fields.COMPLEX_COLUMN, new org.apache.thrift.meta_data.FieldMetaData("complexColumn", org.apache.thrift.TFieldRequirementType.REQUIRED, + new org.apache.thrift.meta_data.MapMetaData(org.apache.thrift.protocol.TType.MAP, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32), + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.StructMetaData(org.apache.thrift.protocol.TType.STRUCT, Nested.class))))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(ParquetThriftCompat.class, metaDataMap); + } + + public ParquetThriftCompat() { + } + + public ParquetThriftCompat( + boolean boolColumn, + byte byteColumn, + short shortColumn, + int intColumn, + long longColumn, + double doubleColumn, + ByteBuffer binaryColumn, + String stringColumn, + Suit enumColumn, + List stringsColumn, + Set intSetColumn, + Map intToStringColumn, + Map> complexColumn) + { + this(); + this.boolColumn = boolColumn; + setBoolColumnIsSet(true); + this.byteColumn = byteColumn; + setByteColumnIsSet(true); + this.shortColumn = shortColumn; + setShortColumnIsSet(true); + this.intColumn = intColumn; + setIntColumnIsSet(true); + this.longColumn = longColumn; + setLongColumnIsSet(true); + this.doubleColumn = doubleColumn; + setDoubleColumnIsSet(true); + this.binaryColumn = org.apache.thrift.TBaseHelper.copyBinary(binaryColumn); + this.stringColumn = stringColumn; + this.enumColumn = enumColumn; + this.stringsColumn = stringsColumn; + this.intSetColumn = intSetColumn; + this.intToStringColumn = intToStringColumn; + this.complexColumn = complexColumn; + } + + /** + * Performs a deep copy on other. 
+ */ + public ParquetThriftCompat(ParquetThriftCompat other) { + __isset_bitfield = other.__isset_bitfield; + this.boolColumn = other.boolColumn; + this.byteColumn = other.byteColumn; + this.shortColumn = other.shortColumn; + this.intColumn = other.intColumn; + this.longColumn = other.longColumn; + this.doubleColumn = other.doubleColumn; + if (other.isSetBinaryColumn()) { + this.binaryColumn = org.apache.thrift.TBaseHelper.copyBinary(other.binaryColumn); + } + if (other.isSetStringColumn()) { + this.stringColumn = other.stringColumn; + } + if (other.isSetEnumColumn()) { + this.enumColumn = other.enumColumn; + } + this.maybeBoolColumn = other.maybeBoolColumn; + this.maybeByteColumn = other.maybeByteColumn; + this.maybeShortColumn = other.maybeShortColumn; + this.maybeIntColumn = other.maybeIntColumn; + this.maybeLongColumn = other.maybeLongColumn; + this.maybeDoubleColumn = other.maybeDoubleColumn; + if (other.isSetMaybeBinaryColumn()) { + this.maybeBinaryColumn = org.apache.thrift.TBaseHelper.copyBinary(other.maybeBinaryColumn); + } + if (other.isSetMaybeStringColumn()) { + this.maybeStringColumn = other.maybeStringColumn; + } + if (other.isSetMaybeEnumColumn()) { + this.maybeEnumColumn = other.maybeEnumColumn; + } + if (other.isSetStringsColumn()) { + List __this__stringsColumn = new ArrayList(other.stringsColumn); + this.stringsColumn = __this__stringsColumn; + } + if (other.isSetIntSetColumn()) { + Set __this__intSetColumn = new HashSet(other.intSetColumn); + this.intSetColumn = __this__intSetColumn; + } + if (other.isSetIntToStringColumn()) { + Map __this__intToStringColumn = new HashMap(other.intToStringColumn); + this.intToStringColumn = __this__intToStringColumn; + } + if (other.isSetComplexColumn()) { + Map> __this__complexColumn = new HashMap>(other.complexColumn.size()); + for (Map.Entry> other_element : other.complexColumn.entrySet()) { + + Integer other_element_key = other_element.getKey(); + List other_element_value = other_element.getValue(); + + Integer __this__complexColumn_copy_key = other_element_key; + + List __this__complexColumn_copy_value = new ArrayList(other_element_value.size()); + for (Nested other_element_value_element : other_element_value) { + __this__complexColumn_copy_value.add(new Nested(other_element_value_element)); + } + + __this__complexColumn.put(__this__complexColumn_copy_key, __this__complexColumn_copy_value); + } + this.complexColumn = __this__complexColumn; + } + } + + public ParquetThriftCompat deepCopy() { + return new ParquetThriftCompat(this); + } + + @Override + public void clear() { + setBoolColumnIsSet(false); + this.boolColumn = false; + setByteColumnIsSet(false); + this.byteColumn = 0; + setShortColumnIsSet(false); + this.shortColumn = 0; + setIntColumnIsSet(false); + this.intColumn = 0; + setLongColumnIsSet(false); + this.longColumn = 0; + setDoubleColumnIsSet(false); + this.doubleColumn = 0.0; + this.binaryColumn = null; + this.stringColumn = null; + this.enumColumn = null; + setMaybeBoolColumnIsSet(false); + this.maybeBoolColumn = false; + setMaybeByteColumnIsSet(false); + this.maybeByteColumn = 0; + setMaybeShortColumnIsSet(false); + this.maybeShortColumn = 0; + setMaybeIntColumnIsSet(false); + this.maybeIntColumn = 0; + setMaybeLongColumnIsSet(false); + this.maybeLongColumn = 0; + setMaybeDoubleColumnIsSet(false); + this.maybeDoubleColumn = 0.0; + this.maybeBinaryColumn = null; + this.maybeStringColumn = null; + this.maybeEnumColumn = null; + this.stringsColumn = null; + this.intSetColumn = null; + this.intToStringColumn = null; + 
this.complexColumn = null; + } + + public boolean isBoolColumn() { + return this.boolColumn; + } + + public ParquetThriftCompat setBoolColumn(boolean boolColumn) { + this.boolColumn = boolColumn; + setBoolColumnIsSet(true); + return this; + } + + public void unsetBoolColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __BOOLCOLUMN_ISSET_ID); + } + + /** Returns true if field boolColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetBoolColumn() { + return EncodingUtils.testBit(__isset_bitfield, __BOOLCOLUMN_ISSET_ID); + } + + public void setBoolColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __BOOLCOLUMN_ISSET_ID, value); + } + + public byte getByteColumn() { + return this.byteColumn; + } + + public ParquetThriftCompat setByteColumn(byte byteColumn) { + this.byteColumn = byteColumn; + setByteColumnIsSet(true); + return this; + } + + public void unsetByteColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __BYTECOLUMN_ISSET_ID); + } + + /** Returns true if field byteColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetByteColumn() { + return EncodingUtils.testBit(__isset_bitfield, __BYTECOLUMN_ISSET_ID); + } + + public void setByteColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __BYTECOLUMN_ISSET_ID, value); + } + + public short getShortColumn() { + return this.shortColumn; + } + + public ParquetThriftCompat setShortColumn(short shortColumn) { + this.shortColumn = shortColumn; + setShortColumnIsSet(true); + return this; + } + + public void unsetShortColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __SHORTCOLUMN_ISSET_ID); + } + + /** Returns true if field shortColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetShortColumn() { + return EncodingUtils.testBit(__isset_bitfield, __SHORTCOLUMN_ISSET_ID); + } + + public void setShortColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __SHORTCOLUMN_ISSET_ID, value); + } + + public int getIntColumn() { + return this.intColumn; + } + + public ParquetThriftCompat setIntColumn(int intColumn) { + this.intColumn = intColumn; + setIntColumnIsSet(true); + return this; + } + + public void unsetIntColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __INTCOLUMN_ISSET_ID); + } + + /** Returns true if field intColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetIntColumn() { + return EncodingUtils.testBit(__isset_bitfield, __INTCOLUMN_ISSET_ID); + } + + public void setIntColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __INTCOLUMN_ISSET_ID, value); + } + + public long getLongColumn() { + return this.longColumn; + } + + public ParquetThriftCompat setLongColumn(long longColumn) { + this.longColumn = longColumn; + setLongColumnIsSet(true); + return this; + } + + public void unsetLongColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __LONGCOLUMN_ISSET_ID); + } + + /** Returns true if field longColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetLongColumn() { + return EncodingUtils.testBit(__isset_bitfield, __LONGCOLUMN_ISSET_ID); + } + + public void setLongColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __LONGCOLUMN_ISSET_ID, value); + } + + public double getDoubleColumn() { + return 
this.doubleColumn; + } + + public ParquetThriftCompat setDoubleColumn(double doubleColumn) { + this.doubleColumn = doubleColumn; + setDoubleColumnIsSet(true); + return this; + } + + public void unsetDoubleColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __DOUBLECOLUMN_ISSET_ID); + } + + /** Returns true if field doubleColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetDoubleColumn() { + return EncodingUtils.testBit(__isset_bitfield, __DOUBLECOLUMN_ISSET_ID); + } + + public void setDoubleColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __DOUBLECOLUMN_ISSET_ID, value); + } + + public byte[] getBinaryColumn() { + setBinaryColumn(org.apache.thrift.TBaseHelper.rightSize(binaryColumn)); + return binaryColumn == null ? null : binaryColumn.array(); + } + + public ByteBuffer bufferForBinaryColumn() { + return org.apache.thrift.TBaseHelper.copyBinary(binaryColumn); + } + + public ParquetThriftCompat setBinaryColumn(byte[] binaryColumn) { + this.binaryColumn = binaryColumn == null ? (ByteBuffer)null : ByteBuffer.wrap(Arrays.copyOf(binaryColumn, binaryColumn.length)); + return this; + } + + public ParquetThriftCompat setBinaryColumn(ByteBuffer binaryColumn) { + this.binaryColumn = org.apache.thrift.TBaseHelper.copyBinary(binaryColumn); + return this; + } + + public void unsetBinaryColumn() { + this.binaryColumn = null; + } + + /** Returns true if field binaryColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetBinaryColumn() { + return this.binaryColumn != null; + } + + public void setBinaryColumnIsSet(boolean value) { + if (!value) { + this.binaryColumn = null; + } + } + + public String getStringColumn() { + return this.stringColumn; + } + + public ParquetThriftCompat setStringColumn(String stringColumn) { + this.stringColumn = stringColumn; + return this; + } + + public void unsetStringColumn() { + this.stringColumn = null; + } + + /** Returns true if field stringColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetStringColumn() { + return this.stringColumn != null; + } + + public void setStringColumnIsSet(boolean value) { + if (!value) { + this.stringColumn = null; + } + } + + /** + * + * @see Suit + */ + public Suit getEnumColumn() { + return this.enumColumn; + } + + /** + * + * @see Suit + */ + public ParquetThriftCompat setEnumColumn(Suit enumColumn) { + this.enumColumn = enumColumn; + return this; + } + + public void unsetEnumColumn() { + this.enumColumn = null; + } + + /** Returns true if field enumColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetEnumColumn() { + return this.enumColumn != null; + } + + public void setEnumColumnIsSet(boolean value) { + if (!value) { + this.enumColumn = null; + } + } + + public boolean isMaybeBoolColumn() { + return this.maybeBoolColumn; + } + + public ParquetThriftCompat setMaybeBoolColumn(boolean maybeBoolColumn) { + this.maybeBoolColumn = maybeBoolColumn; + setMaybeBoolColumnIsSet(true); + return this; + } + + public void unsetMaybeBoolColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBEBOOLCOLUMN_ISSET_ID); + } + + /** Returns true if field maybeBoolColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeBoolColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBEBOOLCOLUMN_ISSET_ID); + } + + public void setMaybeBoolColumnIsSet(boolean value) { + __isset_bitfield = 
EncodingUtils.setBit(__isset_bitfield, __MAYBEBOOLCOLUMN_ISSET_ID, value); + } + + public byte getMaybeByteColumn() { + return this.maybeByteColumn; + } + + public ParquetThriftCompat setMaybeByteColumn(byte maybeByteColumn) { + this.maybeByteColumn = maybeByteColumn; + setMaybeByteColumnIsSet(true); + return this; + } + + public void unsetMaybeByteColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBEBYTECOLUMN_ISSET_ID); + } + + /** Returns true if field maybeByteColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeByteColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBEBYTECOLUMN_ISSET_ID); + } + + public void setMaybeByteColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBEBYTECOLUMN_ISSET_ID, value); + } + + public short getMaybeShortColumn() { + return this.maybeShortColumn; + } + + public ParquetThriftCompat setMaybeShortColumn(short maybeShortColumn) { + this.maybeShortColumn = maybeShortColumn; + setMaybeShortColumnIsSet(true); + return this; + } + + public void unsetMaybeShortColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBESHORTCOLUMN_ISSET_ID); + } + + /** Returns true if field maybeShortColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeShortColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBESHORTCOLUMN_ISSET_ID); + } + + public void setMaybeShortColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBESHORTCOLUMN_ISSET_ID, value); + } + + public int getMaybeIntColumn() { + return this.maybeIntColumn; + } + + public ParquetThriftCompat setMaybeIntColumn(int maybeIntColumn) { + this.maybeIntColumn = maybeIntColumn; + setMaybeIntColumnIsSet(true); + return this; + } + + public void unsetMaybeIntColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBEINTCOLUMN_ISSET_ID); + } + + /** Returns true if field maybeIntColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeIntColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBEINTCOLUMN_ISSET_ID); + } + + public void setMaybeIntColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBEINTCOLUMN_ISSET_ID, value); + } + + public long getMaybeLongColumn() { + return this.maybeLongColumn; + } + + public ParquetThriftCompat setMaybeLongColumn(long maybeLongColumn) { + this.maybeLongColumn = maybeLongColumn; + setMaybeLongColumnIsSet(true); + return this; + } + + public void unsetMaybeLongColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBELONGCOLUMN_ISSET_ID); + } + + /** Returns true if field maybeLongColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeLongColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBELONGCOLUMN_ISSET_ID); + } + + public void setMaybeLongColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBELONGCOLUMN_ISSET_ID, value); + } + + public double getMaybeDoubleColumn() { + return this.maybeDoubleColumn; + } + + public ParquetThriftCompat setMaybeDoubleColumn(double maybeDoubleColumn) { + this.maybeDoubleColumn = maybeDoubleColumn; + setMaybeDoubleColumnIsSet(true); + return this; + } + + public void unsetMaybeDoubleColumn() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __MAYBEDOUBLECOLUMN_ISSET_ID); + } + + /** Returns true if 
field maybeDoubleColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeDoubleColumn() { + return EncodingUtils.testBit(__isset_bitfield, __MAYBEDOUBLECOLUMN_ISSET_ID); + } + + public void setMaybeDoubleColumnIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __MAYBEDOUBLECOLUMN_ISSET_ID, value); + } + + public byte[] getMaybeBinaryColumn() { + setMaybeBinaryColumn(org.apache.thrift.TBaseHelper.rightSize(maybeBinaryColumn)); + return maybeBinaryColumn == null ? null : maybeBinaryColumn.array(); + } + + public ByteBuffer bufferForMaybeBinaryColumn() { + return org.apache.thrift.TBaseHelper.copyBinary(maybeBinaryColumn); + } + + public ParquetThriftCompat setMaybeBinaryColumn(byte[] maybeBinaryColumn) { + this.maybeBinaryColumn = maybeBinaryColumn == null ? (ByteBuffer)null : ByteBuffer.wrap(Arrays.copyOf(maybeBinaryColumn, maybeBinaryColumn.length)); + return this; + } + + public ParquetThriftCompat setMaybeBinaryColumn(ByteBuffer maybeBinaryColumn) { + this.maybeBinaryColumn = org.apache.thrift.TBaseHelper.copyBinary(maybeBinaryColumn); + return this; + } + + public void unsetMaybeBinaryColumn() { + this.maybeBinaryColumn = null; + } + + /** Returns true if field maybeBinaryColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeBinaryColumn() { + return this.maybeBinaryColumn != null; + } + + public void setMaybeBinaryColumnIsSet(boolean value) { + if (!value) { + this.maybeBinaryColumn = null; + } + } + + public String getMaybeStringColumn() { + return this.maybeStringColumn; + } + + public ParquetThriftCompat setMaybeStringColumn(String maybeStringColumn) { + this.maybeStringColumn = maybeStringColumn; + return this; + } + + public void unsetMaybeStringColumn() { + this.maybeStringColumn = null; + } + + /** Returns true if field maybeStringColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeStringColumn() { + return this.maybeStringColumn != null; + } + + public void setMaybeStringColumnIsSet(boolean value) { + if (!value) { + this.maybeStringColumn = null; + } + } + + /** + * + * @see Suit + */ + public Suit getMaybeEnumColumn() { + return this.maybeEnumColumn; + } + + /** + * + * @see Suit + */ + public ParquetThriftCompat setMaybeEnumColumn(Suit maybeEnumColumn) { + this.maybeEnumColumn = maybeEnumColumn; + return this; + } + + public void unsetMaybeEnumColumn() { + this.maybeEnumColumn = null; + } + + /** Returns true if field maybeEnumColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetMaybeEnumColumn() { + return this.maybeEnumColumn != null; + } + + public void setMaybeEnumColumnIsSet(boolean value) { + if (!value) { + this.maybeEnumColumn = null; + } + } + + public int getStringsColumnSize() { + return (this.stringsColumn == null) ? 0 : this.stringsColumn.size(); + } + + public java.util.Iterator getStringsColumnIterator() { + return (this.stringsColumn == null) ? 
null : this.stringsColumn.iterator();
+  }
+
+  public void addToStringsColumn(String elem) {
+    if (this.stringsColumn == null) {
+      this.stringsColumn = new ArrayList<String>();
+    }
+    this.stringsColumn.add(elem);
+  }
+
+  public List<String> getStringsColumn() {
+    return this.stringsColumn;
+  }
+
+  public ParquetThriftCompat setStringsColumn(List<String> stringsColumn) {
+    this.stringsColumn = stringsColumn;
+    return this;
+  }
+
+  public void unsetStringsColumn() {
+    this.stringsColumn = null;
+  }
+
+  /** Returns true if field stringsColumn is set (has been assigned a value) and false otherwise */
+  public boolean isSetStringsColumn() {
+    return this.stringsColumn != null;
+  }
+
+  public void setStringsColumnIsSet(boolean value) {
+    if (!value) {
+      this.stringsColumn = null;
+    }
+  }
+
+  public int getIntSetColumnSize() {
+    return (this.intSetColumn == null) ? 0 : this.intSetColumn.size();
+  }
+
+  public java.util.Iterator<Integer> getIntSetColumnIterator() {
+    return (this.intSetColumn == null) ? null : this.intSetColumn.iterator();
+  }
+
+  public void addToIntSetColumn(int elem) {
+    if (this.intSetColumn == null) {
+      this.intSetColumn = new HashSet<Integer>();
+    }
+    this.intSetColumn.add(elem);
+  }
+
+  public Set<Integer> getIntSetColumn() {
+    return this.intSetColumn;
+  }
+
+  public ParquetThriftCompat setIntSetColumn(Set<Integer> intSetColumn) {
+    this.intSetColumn = intSetColumn;
+    return this;
+  }
+
+  public void unsetIntSetColumn() {
+    this.intSetColumn = null;
+  }
+
+  /** Returns true if field intSetColumn is set (has been assigned a value) and false otherwise */
+  public boolean isSetIntSetColumn() {
+    return this.intSetColumn != null;
+  }
+
+  public void setIntSetColumnIsSet(boolean value) {
+    if (!value) {
+      this.intSetColumn = null;
+    }
+  }
+
+  public int getIntToStringColumnSize() {
+    return (this.intToStringColumn == null) ? 0 : this.intToStringColumn.size();
+  }
+
+  public void putToIntToStringColumn(int key, String val) {
+    if (this.intToStringColumn == null) {
+      this.intToStringColumn = new HashMap<Integer,String>();
+    }
+    this.intToStringColumn.put(key, val);
+  }
+
+  public Map<Integer,String> getIntToStringColumn() {
+    return this.intToStringColumn;
+  }
+
+  public ParquetThriftCompat setIntToStringColumn(Map<Integer,String> intToStringColumn) {
+    this.intToStringColumn = intToStringColumn;
+    return this;
+  }
+
+  public void unsetIntToStringColumn() {
+    this.intToStringColumn = null;
+  }
+
+  /** Returns true if field intToStringColumn is set (has been assigned a value) and false otherwise */
+  public boolean isSetIntToStringColumn() {
+    return this.intToStringColumn != null;
+  }
+
+  public void setIntToStringColumnIsSet(boolean value) {
+    if (!value) {
+      this.intToStringColumn = null;
+    }
+  }
+
+  public int getComplexColumnSize() {
+    return (this.complexColumn == null) ?
0 : this.complexColumn.size(); + } + + public void putToComplexColumn(int key, List val) { + if (this.complexColumn == null) { + this.complexColumn = new HashMap>(); + } + this.complexColumn.put(key, val); + } + + public Map> getComplexColumn() { + return this.complexColumn; + } + + public ParquetThriftCompat setComplexColumn(Map> complexColumn) { + this.complexColumn = complexColumn; + return this; + } + + public void unsetComplexColumn() { + this.complexColumn = null; + } + + /** Returns true if field complexColumn is set (has been assigned a value) and false otherwise */ + public boolean isSetComplexColumn() { + return this.complexColumn != null; + } + + public void setComplexColumnIsSet(boolean value) { + if (!value) { + this.complexColumn = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case BOOL_COLUMN: + if (value == null) { + unsetBoolColumn(); + } else { + setBoolColumn((Boolean)value); + } + break; + + case BYTE_COLUMN: + if (value == null) { + unsetByteColumn(); + } else { + setByteColumn((Byte)value); + } + break; + + case SHORT_COLUMN: + if (value == null) { + unsetShortColumn(); + } else { + setShortColumn((Short)value); + } + break; + + case INT_COLUMN: + if (value == null) { + unsetIntColumn(); + } else { + setIntColumn((Integer)value); + } + break; + + case LONG_COLUMN: + if (value == null) { + unsetLongColumn(); + } else { + setLongColumn((Long)value); + } + break; + + case DOUBLE_COLUMN: + if (value == null) { + unsetDoubleColumn(); + } else { + setDoubleColumn((Double)value); + } + break; + + case BINARY_COLUMN: + if (value == null) { + unsetBinaryColumn(); + } else { + setBinaryColumn((ByteBuffer)value); + } + break; + + case STRING_COLUMN: + if (value == null) { + unsetStringColumn(); + } else { + setStringColumn((String)value); + } + break; + + case ENUM_COLUMN: + if (value == null) { + unsetEnumColumn(); + } else { + setEnumColumn((Suit)value); + } + break; + + case MAYBE_BOOL_COLUMN: + if (value == null) { + unsetMaybeBoolColumn(); + } else { + setMaybeBoolColumn((Boolean)value); + } + break; + + case MAYBE_BYTE_COLUMN: + if (value == null) { + unsetMaybeByteColumn(); + } else { + setMaybeByteColumn((Byte)value); + } + break; + + case MAYBE_SHORT_COLUMN: + if (value == null) { + unsetMaybeShortColumn(); + } else { + setMaybeShortColumn((Short)value); + } + break; + + case MAYBE_INT_COLUMN: + if (value == null) { + unsetMaybeIntColumn(); + } else { + setMaybeIntColumn((Integer)value); + } + break; + + case MAYBE_LONG_COLUMN: + if (value == null) { + unsetMaybeLongColumn(); + } else { + setMaybeLongColumn((Long)value); + } + break; + + case MAYBE_DOUBLE_COLUMN: + if (value == null) { + unsetMaybeDoubleColumn(); + } else { + setMaybeDoubleColumn((Double)value); + } + break; + + case MAYBE_BINARY_COLUMN: + if (value == null) { + unsetMaybeBinaryColumn(); + } else { + setMaybeBinaryColumn((ByteBuffer)value); + } + break; + + case MAYBE_STRING_COLUMN: + if (value == null) { + unsetMaybeStringColumn(); + } else { + setMaybeStringColumn((String)value); + } + break; + + case MAYBE_ENUM_COLUMN: + if (value == null) { + unsetMaybeEnumColumn(); + } else { + setMaybeEnumColumn((Suit)value); + } + break; + + case STRINGS_COLUMN: + if (value == null) { + unsetStringsColumn(); + } else { + setStringsColumn((List)value); + } + break; + + case INT_SET_COLUMN: + if (value == null) { + unsetIntSetColumn(); + } else { + setIntSetColumn((Set)value); + } + break; + + case INT_TO_STRING_COLUMN: + if (value == null) { + 
unsetIntToStringColumn(); + } else { + setIntToStringColumn((Map)value); + } + break; + + case COMPLEX_COLUMN: + if (value == null) { + unsetComplexColumn(); + } else { + setComplexColumn((Map>)value); + } + break; + + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case BOOL_COLUMN: + return Boolean.valueOf(isBoolColumn()); + + case BYTE_COLUMN: + return Byte.valueOf(getByteColumn()); + + case SHORT_COLUMN: + return Short.valueOf(getShortColumn()); + + case INT_COLUMN: + return Integer.valueOf(getIntColumn()); + + case LONG_COLUMN: + return Long.valueOf(getLongColumn()); + + case DOUBLE_COLUMN: + return Double.valueOf(getDoubleColumn()); + + case BINARY_COLUMN: + return getBinaryColumn(); + + case STRING_COLUMN: + return getStringColumn(); + + case ENUM_COLUMN: + return getEnumColumn(); + + case MAYBE_BOOL_COLUMN: + return Boolean.valueOf(isMaybeBoolColumn()); + + case MAYBE_BYTE_COLUMN: + return Byte.valueOf(getMaybeByteColumn()); + + case MAYBE_SHORT_COLUMN: + return Short.valueOf(getMaybeShortColumn()); + + case MAYBE_INT_COLUMN: + return Integer.valueOf(getMaybeIntColumn()); + + case MAYBE_LONG_COLUMN: + return Long.valueOf(getMaybeLongColumn()); + + case MAYBE_DOUBLE_COLUMN: + return Double.valueOf(getMaybeDoubleColumn()); + + case MAYBE_BINARY_COLUMN: + return getMaybeBinaryColumn(); + + case MAYBE_STRING_COLUMN: + return getMaybeStringColumn(); + + case MAYBE_ENUM_COLUMN: + return getMaybeEnumColumn(); + + case STRINGS_COLUMN: + return getStringsColumn(); + + case INT_SET_COLUMN: + return getIntSetColumn(); + + case INT_TO_STRING_COLUMN: + return getIntToStringColumn(); + + case COMPLEX_COLUMN: + return getComplexColumn(); + + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case BOOL_COLUMN: + return isSetBoolColumn(); + case BYTE_COLUMN: + return isSetByteColumn(); + case SHORT_COLUMN: + return isSetShortColumn(); + case INT_COLUMN: + return isSetIntColumn(); + case LONG_COLUMN: + return isSetLongColumn(); + case DOUBLE_COLUMN: + return isSetDoubleColumn(); + case BINARY_COLUMN: + return isSetBinaryColumn(); + case STRING_COLUMN: + return isSetStringColumn(); + case ENUM_COLUMN: + return isSetEnumColumn(); + case MAYBE_BOOL_COLUMN: + return isSetMaybeBoolColumn(); + case MAYBE_BYTE_COLUMN: + return isSetMaybeByteColumn(); + case MAYBE_SHORT_COLUMN: + return isSetMaybeShortColumn(); + case MAYBE_INT_COLUMN: + return isSetMaybeIntColumn(); + case MAYBE_LONG_COLUMN: + return isSetMaybeLongColumn(); + case MAYBE_DOUBLE_COLUMN: + return isSetMaybeDoubleColumn(); + case MAYBE_BINARY_COLUMN: + return isSetMaybeBinaryColumn(); + case MAYBE_STRING_COLUMN: + return isSetMaybeStringColumn(); + case MAYBE_ENUM_COLUMN: + return isSetMaybeEnumColumn(); + case STRINGS_COLUMN: + return isSetStringsColumn(); + case INT_SET_COLUMN: + return isSetIntSetColumn(); + case INT_TO_STRING_COLUMN: + return isSetIntToStringColumn(); + case COMPLEX_COLUMN: + return isSetComplexColumn(); + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) + return false; + if (that instanceof ParquetThriftCompat) + return this.equals((ParquetThriftCompat)that); + return false; + } + + public boolean equals(ParquetThriftCompat that) { + if (that == null) + return false; + + boolean 
this_present_boolColumn = true; + boolean that_present_boolColumn = true; + if (this_present_boolColumn || that_present_boolColumn) { + if (!(this_present_boolColumn && that_present_boolColumn)) + return false; + if (this.boolColumn != that.boolColumn) + return false; + } + + boolean this_present_byteColumn = true; + boolean that_present_byteColumn = true; + if (this_present_byteColumn || that_present_byteColumn) { + if (!(this_present_byteColumn && that_present_byteColumn)) + return false; + if (this.byteColumn != that.byteColumn) + return false; + } + + boolean this_present_shortColumn = true; + boolean that_present_shortColumn = true; + if (this_present_shortColumn || that_present_shortColumn) { + if (!(this_present_shortColumn && that_present_shortColumn)) + return false; + if (this.shortColumn != that.shortColumn) + return false; + } + + boolean this_present_intColumn = true; + boolean that_present_intColumn = true; + if (this_present_intColumn || that_present_intColumn) { + if (!(this_present_intColumn && that_present_intColumn)) + return false; + if (this.intColumn != that.intColumn) + return false; + } + + boolean this_present_longColumn = true; + boolean that_present_longColumn = true; + if (this_present_longColumn || that_present_longColumn) { + if (!(this_present_longColumn && that_present_longColumn)) + return false; + if (this.longColumn != that.longColumn) + return false; + } + + boolean this_present_doubleColumn = true; + boolean that_present_doubleColumn = true; + if (this_present_doubleColumn || that_present_doubleColumn) { + if (!(this_present_doubleColumn && that_present_doubleColumn)) + return false; + if (this.doubleColumn != that.doubleColumn) + return false; + } + + boolean this_present_binaryColumn = true && this.isSetBinaryColumn(); + boolean that_present_binaryColumn = true && that.isSetBinaryColumn(); + if (this_present_binaryColumn || that_present_binaryColumn) { + if (!(this_present_binaryColumn && that_present_binaryColumn)) + return false; + if (!this.binaryColumn.equals(that.binaryColumn)) + return false; + } + + boolean this_present_stringColumn = true && this.isSetStringColumn(); + boolean that_present_stringColumn = true && that.isSetStringColumn(); + if (this_present_stringColumn || that_present_stringColumn) { + if (!(this_present_stringColumn && that_present_stringColumn)) + return false; + if (!this.stringColumn.equals(that.stringColumn)) + return false; + } + + boolean this_present_enumColumn = true && this.isSetEnumColumn(); + boolean that_present_enumColumn = true && that.isSetEnumColumn(); + if (this_present_enumColumn || that_present_enumColumn) { + if (!(this_present_enumColumn && that_present_enumColumn)) + return false; + if (!this.enumColumn.equals(that.enumColumn)) + return false; + } + + boolean this_present_maybeBoolColumn = true && this.isSetMaybeBoolColumn(); + boolean that_present_maybeBoolColumn = true && that.isSetMaybeBoolColumn(); + if (this_present_maybeBoolColumn || that_present_maybeBoolColumn) { + if (!(this_present_maybeBoolColumn && that_present_maybeBoolColumn)) + return false; + if (this.maybeBoolColumn != that.maybeBoolColumn) + return false; + } + + boolean this_present_maybeByteColumn = true && this.isSetMaybeByteColumn(); + boolean that_present_maybeByteColumn = true && that.isSetMaybeByteColumn(); + if (this_present_maybeByteColumn || that_present_maybeByteColumn) { + if (!(this_present_maybeByteColumn && that_present_maybeByteColumn)) + return false; + if (this.maybeByteColumn != that.maybeByteColumn) + return false; + 
} + + boolean this_present_maybeShortColumn = true && this.isSetMaybeShortColumn(); + boolean that_present_maybeShortColumn = true && that.isSetMaybeShortColumn(); + if (this_present_maybeShortColumn || that_present_maybeShortColumn) { + if (!(this_present_maybeShortColumn && that_present_maybeShortColumn)) + return false; + if (this.maybeShortColumn != that.maybeShortColumn) + return false; + } + + boolean this_present_maybeIntColumn = true && this.isSetMaybeIntColumn(); + boolean that_present_maybeIntColumn = true && that.isSetMaybeIntColumn(); + if (this_present_maybeIntColumn || that_present_maybeIntColumn) { + if (!(this_present_maybeIntColumn && that_present_maybeIntColumn)) + return false; + if (this.maybeIntColumn != that.maybeIntColumn) + return false; + } + + boolean this_present_maybeLongColumn = true && this.isSetMaybeLongColumn(); + boolean that_present_maybeLongColumn = true && that.isSetMaybeLongColumn(); + if (this_present_maybeLongColumn || that_present_maybeLongColumn) { + if (!(this_present_maybeLongColumn && that_present_maybeLongColumn)) + return false; + if (this.maybeLongColumn != that.maybeLongColumn) + return false; + } + + boolean this_present_maybeDoubleColumn = true && this.isSetMaybeDoubleColumn(); + boolean that_present_maybeDoubleColumn = true && that.isSetMaybeDoubleColumn(); + if (this_present_maybeDoubleColumn || that_present_maybeDoubleColumn) { + if (!(this_present_maybeDoubleColumn && that_present_maybeDoubleColumn)) + return false; + if (this.maybeDoubleColumn != that.maybeDoubleColumn) + return false; + } + + boolean this_present_maybeBinaryColumn = true && this.isSetMaybeBinaryColumn(); + boolean that_present_maybeBinaryColumn = true && that.isSetMaybeBinaryColumn(); + if (this_present_maybeBinaryColumn || that_present_maybeBinaryColumn) { + if (!(this_present_maybeBinaryColumn && that_present_maybeBinaryColumn)) + return false; + if (!this.maybeBinaryColumn.equals(that.maybeBinaryColumn)) + return false; + } + + boolean this_present_maybeStringColumn = true && this.isSetMaybeStringColumn(); + boolean that_present_maybeStringColumn = true && that.isSetMaybeStringColumn(); + if (this_present_maybeStringColumn || that_present_maybeStringColumn) { + if (!(this_present_maybeStringColumn && that_present_maybeStringColumn)) + return false; + if (!this.maybeStringColumn.equals(that.maybeStringColumn)) + return false; + } + + boolean this_present_maybeEnumColumn = true && this.isSetMaybeEnumColumn(); + boolean that_present_maybeEnumColumn = true && that.isSetMaybeEnumColumn(); + if (this_present_maybeEnumColumn || that_present_maybeEnumColumn) { + if (!(this_present_maybeEnumColumn && that_present_maybeEnumColumn)) + return false; + if (!this.maybeEnumColumn.equals(that.maybeEnumColumn)) + return false; + } + + boolean this_present_stringsColumn = true && this.isSetStringsColumn(); + boolean that_present_stringsColumn = true && that.isSetStringsColumn(); + if (this_present_stringsColumn || that_present_stringsColumn) { + if (!(this_present_stringsColumn && that_present_stringsColumn)) + return false; + if (!this.stringsColumn.equals(that.stringsColumn)) + return false; + } + + boolean this_present_intSetColumn = true && this.isSetIntSetColumn(); + boolean that_present_intSetColumn = true && that.isSetIntSetColumn(); + if (this_present_intSetColumn || that_present_intSetColumn) { + if (!(this_present_intSetColumn && that_present_intSetColumn)) + return false; + if (!this.intSetColumn.equals(that.intSetColumn)) + return false; + } + + boolean 
this_present_intToStringColumn = true && this.isSetIntToStringColumn(); + boolean that_present_intToStringColumn = true && that.isSetIntToStringColumn(); + if (this_present_intToStringColumn || that_present_intToStringColumn) { + if (!(this_present_intToStringColumn && that_present_intToStringColumn)) + return false; + if (!this.intToStringColumn.equals(that.intToStringColumn)) + return false; + } + + boolean this_present_complexColumn = true && this.isSetComplexColumn(); + boolean that_present_complexColumn = true && that.isSetComplexColumn(); + if (this_present_complexColumn || that_present_complexColumn) { + if (!(this_present_complexColumn && that_present_complexColumn)) + return false; + if (!this.complexColumn.equals(that.complexColumn)) + return false; + } + + return true; + } + + @Override + public int hashCode() { + List list = new ArrayList(); + + boolean present_boolColumn = true; + list.add(present_boolColumn); + if (present_boolColumn) + list.add(boolColumn); + + boolean present_byteColumn = true; + list.add(present_byteColumn); + if (present_byteColumn) + list.add(byteColumn); + + boolean present_shortColumn = true; + list.add(present_shortColumn); + if (present_shortColumn) + list.add(shortColumn); + + boolean present_intColumn = true; + list.add(present_intColumn); + if (present_intColumn) + list.add(intColumn); + + boolean present_longColumn = true; + list.add(present_longColumn); + if (present_longColumn) + list.add(longColumn); + + boolean present_doubleColumn = true; + list.add(present_doubleColumn); + if (present_doubleColumn) + list.add(doubleColumn); + + boolean present_binaryColumn = true && (isSetBinaryColumn()); + list.add(present_binaryColumn); + if (present_binaryColumn) + list.add(binaryColumn); + + boolean present_stringColumn = true && (isSetStringColumn()); + list.add(present_stringColumn); + if (present_stringColumn) + list.add(stringColumn); + + boolean present_enumColumn = true && (isSetEnumColumn()); + list.add(present_enumColumn); + if (present_enumColumn) + list.add(enumColumn.getValue()); + + boolean present_maybeBoolColumn = true && (isSetMaybeBoolColumn()); + list.add(present_maybeBoolColumn); + if (present_maybeBoolColumn) + list.add(maybeBoolColumn); + + boolean present_maybeByteColumn = true && (isSetMaybeByteColumn()); + list.add(present_maybeByteColumn); + if (present_maybeByteColumn) + list.add(maybeByteColumn); + + boolean present_maybeShortColumn = true && (isSetMaybeShortColumn()); + list.add(present_maybeShortColumn); + if (present_maybeShortColumn) + list.add(maybeShortColumn); + + boolean present_maybeIntColumn = true && (isSetMaybeIntColumn()); + list.add(present_maybeIntColumn); + if (present_maybeIntColumn) + list.add(maybeIntColumn); + + boolean present_maybeLongColumn = true && (isSetMaybeLongColumn()); + list.add(present_maybeLongColumn); + if (present_maybeLongColumn) + list.add(maybeLongColumn); + + boolean present_maybeDoubleColumn = true && (isSetMaybeDoubleColumn()); + list.add(present_maybeDoubleColumn); + if (present_maybeDoubleColumn) + list.add(maybeDoubleColumn); + + boolean present_maybeBinaryColumn = true && (isSetMaybeBinaryColumn()); + list.add(present_maybeBinaryColumn); + if (present_maybeBinaryColumn) + list.add(maybeBinaryColumn); + + boolean present_maybeStringColumn = true && (isSetMaybeStringColumn()); + list.add(present_maybeStringColumn); + if (present_maybeStringColumn) + list.add(maybeStringColumn); + + boolean present_maybeEnumColumn = true && (isSetMaybeEnumColumn()); + list.add(present_maybeEnumColumn); 
+ if (present_maybeEnumColumn) + list.add(maybeEnumColumn.getValue()); + + boolean present_stringsColumn = true && (isSetStringsColumn()); + list.add(present_stringsColumn); + if (present_stringsColumn) + list.add(stringsColumn); + + boolean present_intSetColumn = true && (isSetIntSetColumn()); + list.add(present_intSetColumn); + if (present_intSetColumn) + list.add(intSetColumn); + + boolean present_intToStringColumn = true && (isSetIntToStringColumn()); + list.add(present_intToStringColumn); + if (present_intToStringColumn) + list.add(intToStringColumn); + + boolean present_complexColumn = true && (isSetComplexColumn()); + list.add(present_complexColumn); + if (present_complexColumn) + list.add(complexColumn); + + return list.hashCode(); + } + + @Override + public int compareTo(ParquetThriftCompat other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + + lastComparison = Boolean.valueOf(isSetBoolColumn()).compareTo(other.isSetBoolColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetBoolColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.boolColumn, other.boolColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetByteColumn()).compareTo(other.isSetByteColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetByteColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.byteColumn, other.byteColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetShortColumn()).compareTo(other.isSetShortColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetShortColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.shortColumn, other.shortColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetIntColumn()).compareTo(other.isSetIntColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetIntColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.intColumn, other.intColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLongColumn()).compareTo(other.isSetLongColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLongColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.longColumn, other.longColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetDoubleColumn()).compareTo(other.isSetDoubleColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetDoubleColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.doubleColumn, other.doubleColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetBinaryColumn()).compareTo(other.isSetBinaryColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetBinaryColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.binaryColumn, other.binaryColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetStringColumn()).compareTo(other.isSetStringColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetStringColumn()) { + lastComparison = 
org.apache.thrift.TBaseHelper.compareTo(this.stringColumn, other.stringColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetEnumColumn()).compareTo(other.isSetEnumColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetEnumColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.enumColumn, other.enumColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeBoolColumn()).compareTo(other.isSetMaybeBoolColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeBoolColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeBoolColumn, other.maybeBoolColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeByteColumn()).compareTo(other.isSetMaybeByteColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeByteColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeByteColumn, other.maybeByteColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeShortColumn()).compareTo(other.isSetMaybeShortColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeShortColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeShortColumn, other.maybeShortColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeIntColumn()).compareTo(other.isSetMaybeIntColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeIntColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeIntColumn, other.maybeIntColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeLongColumn()).compareTo(other.isSetMaybeLongColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeLongColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeLongColumn, other.maybeLongColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeDoubleColumn()).compareTo(other.isSetMaybeDoubleColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeDoubleColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeDoubleColumn, other.maybeDoubleColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeBinaryColumn()).compareTo(other.isSetMaybeBinaryColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeBinaryColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeBinaryColumn, other.maybeBinaryColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeStringColumn()).compareTo(other.isSetMaybeStringColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeStringColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeStringColumn, other.maybeStringColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMaybeEnumColumn()).compareTo(other.isSetMaybeEnumColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMaybeEnumColumn()) { + 
lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.maybeEnumColumn, other.maybeEnumColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetStringsColumn()).compareTo(other.isSetStringsColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetStringsColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.stringsColumn, other.stringsColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetIntSetColumn()).compareTo(other.isSetIntSetColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetIntSetColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.intSetColumn, other.intSetColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetIntToStringColumn()).compareTo(other.isSetIntToStringColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetIntToStringColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.intToStringColumn, other.intToStringColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetComplexColumn()).compareTo(other.isSetComplexColumn()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetComplexColumn()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.complexColumn, other.complexColumn); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("ParquetThriftCompat("); + boolean first = true; + + sb.append("boolColumn:"); + sb.append(this.boolColumn); + first = false; + if (!first) sb.append(", "); + sb.append("byteColumn:"); + sb.append(this.byteColumn); + first = false; + if (!first) sb.append(", "); + sb.append("shortColumn:"); + sb.append(this.shortColumn); + first = false; + if (!first) sb.append(", "); + sb.append("intColumn:"); + sb.append(this.intColumn); + first = false; + if (!first) sb.append(", "); + sb.append("longColumn:"); + sb.append(this.longColumn); + first = false; + if (!first) sb.append(", "); + sb.append("doubleColumn:"); + sb.append(this.doubleColumn); + first = false; + if (!first) sb.append(", "); + sb.append("binaryColumn:"); + if (this.binaryColumn == null) { + sb.append("null"); + } else { + org.apache.thrift.TBaseHelper.toString(this.binaryColumn, sb); + } + first = false; + if (!first) sb.append(", "); + sb.append("stringColumn:"); + if (this.stringColumn == null) { + sb.append("null"); + } else { + sb.append(this.stringColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("enumColumn:"); + if (this.enumColumn == null) { + sb.append("null"); + } else { + sb.append(this.enumColumn); + } + first = false; + if (isSetMaybeBoolColumn()) { + if (!first) sb.append(", "); + sb.append("maybeBoolColumn:"); + sb.append(this.maybeBoolColumn); + first = false; + } + if (isSetMaybeByteColumn()) { + if (!first) sb.append(", "); + 
sb.append("maybeByteColumn:"); + sb.append(this.maybeByteColumn); + first = false; + } + if (isSetMaybeShortColumn()) { + if (!first) sb.append(", "); + sb.append("maybeShortColumn:"); + sb.append(this.maybeShortColumn); + first = false; + } + if (isSetMaybeIntColumn()) { + if (!first) sb.append(", "); + sb.append("maybeIntColumn:"); + sb.append(this.maybeIntColumn); + first = false; + } + if (isSetMaybeLongColumn()) { + if (!first) sb.append(", "); + sb.append("maybeLongColumn:"); + sb.append(this.maybeLongColumn); + first = false; + } + if (isSetMaybeDoubleColumn()) { + if (!first) sb.append(", "); + sb.append("maybeDoubleColumn:"); + sb.append(this.maybeDoubleColumn); + first = false; + } + if (isSetMaybeBinaryColumn()) { + if (!first) sb.append(", "); + sb.append("maybeBinaryColumn:"); + if (this.maybeBinaryColumn == null) { + sb.append("null"); + } else { + org.apache.thrift.TBaseHelper.toString(this.maybeBinaryColumn, sb); + } + first = false; + } + if (isSetMaybeStringColumn()) { + if (!first) sb.append(", "); + sb.append("maybeStringColumn:"); + if (this.maybeStringColumn == null) { + sb.append("null"); + } else { + sb.append(this.maybeStringColumn); + } + first = false; + } + if (isSetMaybeEnumColumn()) { + if (!first) sb.append(", "); + sb.append("maybeEnumColumn:"); + if (this.maybeEnumColumn == null) { + sb.append("null"); + } else { + sb.append(this.maybeEnumColumn); + } + first = false; + } + if (!first) sb.append(", "); + sb.append("stringsColumn:"); + if (this.stringsColumn == null) { + sb.append("null"); + } else { + sb.append(this.stringsColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("intSetColumn:"); + if (this.intSetColumn == null) { + sb.append("null"); + } else { + sb.append(this.intSetColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("intToStringColumn:"); + if (this.intToStringColumn == null) { + sb.append("null"); + } else { + sb.append(this.intToStringColumn); + } + first = false; + if (!first) sb.append(", "); + sb.append("complexColumn:"); + if (this.complexColumn == null) { + sb.append("null"); + } else { + sb.append(this.complexColumn); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + // alas, we cannot check 'boolColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'byteColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'shortColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'intColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'longColumn' because it's a primitive and you chose the non-beans generator. + // alas, we cannot check 'doubleColumn' because it's a primitive and you chose the non-beans generator. + if (binaryColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'binaryColumn' was not present! Struct: " + toString()); + } + if (stringColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'stringColumn' was not present! Struct: " + toString()); + } + if (enumColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'enumColumn' was not present! 
Struct: " + toString()); + } + if (stringsColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'stringsColumn' was not present! Struct: " + toString()); + } + if (intSetColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'intSetColumn' was not present! Struct: " + toString()); + } + if (intToStringColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'intToStringColumn' was not present! Struct: " + toString()); + } + if (complexColumn == null) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'complexColumn' was not present! Struct: " + toString()); + } + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call the default constructor. + __isset_bitfield = 0; + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class ParquetThriftCompatStandardSchemeFactory implements SchemeFactory { + public ParquetThriftCompatStandardScheme getScheme() { + return new ParquetThriftCompatStandardScheme(); + } + } + + private static class ParquetThriftCompatStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, ParquetThriftCompat struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // BOOL_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.BOOL) { + struct.boolColumn = iprot.readBool(); + struct.setBoolColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 2: // BYTE_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.BYTE) { + struct.byteColumn = iprot.readByte(); + struct.setByteColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 3: // SHORT_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I16) { + struct.shortColumn = iprot.readI16(); + struct.setShortColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 4: // INT_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.intColumn = iprot.readI32(); + struct.setIntColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 5: // LONG_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I64) { + struct.longColumn = iprot.readI64(); + struct.setLongColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 6: // DOUBLE_COLUMN + if 
(schemeField.type == org.apache.thrift.protocol.TType.DOUBLE) { + struct.doubleColumn = iprot.readDouble(); + struct.setDoubleColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 7: // BINARY_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.binaryColumn = iprot.readBinary(); + struct.setBinaryColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 8: // STRING_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.stringColumn = iprot.readString(); + struct.setStringColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 9: // ENUM_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.enumColumn = org.apache.spark.sql.parquet.test.thrift.Suit.findByValue(iprot.readI32()); + struct.setEnumColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 10: // MAYBE_BOOL_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.BOOL) { + struct.maybeBoolColumn = iprot.readBool(); + struct.setMaybeBoolColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 11: // MAYBE_BYTE_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.BYTE) { + struct.maybeByteColumn = iprot.readByte(); + struct.setMaybeByteColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 12: // MAYBE_SHORT_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I16) { + struct.maybeShortColumn = iprot.readI16(); + struct.setMaybeShortColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 13: // MAYBE_INT_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.maybeIntColumn = iprot.readI32(); + struct.setMaybeIntColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 14: // MAYBE_LONG_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I64) { + struct.maybeLongColumn = iprot.readI64(); + struct.setMaybeLongColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 15: // MAYBE_DOUBLE_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.DOUBLE) { + struct.maybeDoubleColumn = iprot.readDouble(); + struct.setMaybeDoubleColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 16: // MAYBE_BINARY_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.maybeBinaryColumn = iprot.readBinary(); + struct.setMaybeBinaryColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 17: // MAYBE_STRING_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.maybeStringColumn = iprot.readString(); + struct.setMaybeStringColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 18: // MAYBE_ENUM_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.maybeEnumColumn = 
org.apache.spark.sql.parquet.test.thrift.Suit.findByValue(iprot.readI32()); + struct.setMaybeEnumColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 19: // STRINGS_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list8 = iprot.readListBegin(); + struct.stringsColumn = new ArrayList(_list8.size); + String _elem9; + for (int _i10 = 0; _i10 < _list8.size; ++_i10) + { + _elem9 = iprot.readString(); + struct.stringsColumn.add(_elem9); + } + iprot.readListEnd(); + } + struct.setStringsColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 20: // INT_SET_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.SET) { + { + org.apache.thrift.protocol.TSet _set11 = iprot.readSetBegin(); + struct.intSetColumn = new HashSet(2*_set11.size); + int _elem12; + for (int _i13 = 0; _i13 < _set11.size; ++_i13) + { + _elem12 = iprot.readI32(); + struct.intSetColumn.add(_elem12); + } + iprot.readSetEnd(); + } + struct.setIntSetColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 21: // INT_TO_STRING_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.MAP) { + { + org.apache.thrift.protocol.TMap _map14 = iprot.readMapBegin(); + struct.intToStringColumn = new HashMap(2*_map14.size); + int _key15; + String _val16; + for (int _i17 = 0; _i17 < _map14.size; ++_i17) + { + _key15 = iprot.readI32(); + _val16 = iprot.readString(); + struct.intToStringColumn.put(_key15, _val16); + } + iprot.readMapEnd(); + } + struct.setIntToStringColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 22: // COMPLEX_COLUMN + if (schemeField.type == org.apache.thrift.protocol.TType.MAP) { + { + org.apache.thrift.protocol.TMap _map18 = iprot.readMapBegin(); + struct.complexColumn = new HashMap>(2*_map18.size); + int _key19; + List _val20; + for (int _i21 = 0; _i21 < _map18.size; ++_i21) + { + _key19 = iprot.readI32(); + { + org.apache.thrift.protocol.TList _list22 = iprot.readListBegin(); + _val20 = new ArrayList(_list22.size); + Nested _elem23; + for (int _i24 = 0; _i24 < _list22.size; ++_i24) + { + _elem23 = new Nested(); + _elem23.read(iprot); + _val20.add(_elem23); + } + iprot.readListEnd(); + } + struct.complexColumn.put(_key19, _val20); + } + iprot.readMapEnd(); + } + struct.setComplexColumnIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + + // check for required fields of primitive type, which can't be checked in the validate method + if (!struct.isSetBoolColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'boolColumn' was not found in serialized data! Struct: " + toString()); + } + if (!struct.isSetByteColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'byteColumn' was not found in serialized data! Struct: " + toString()); + } + if (!struct.isSetShortColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'shortColumn' was not found in serialized data! 
Struct: " + toString()); + } + if (!struct.isSetIntColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'intColumn' was not found in serialized data! Struct: " + toString()); + } + if (!struct.isSetLongColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'longColumn' was not found in serialized data! Struct: " + toString()); + } + if (!struct.isSetDoubleColumn()) { + throw new org.apache.thrift.protocol.TProtocolException("Required field 'doubleColumn' was not found in serialized data! Struct: " + toString()); + } + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, ParquetThriftCompat struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + oprot.writeFieldBegin(BOOL_COLUMN_FIELD_DESC); + oprot.writeBool(struct.boolColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(BYTE_COLUMN_FIELD_DESC); + oprot.writeByte(struct.byteColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(SHORT_COLUMN_FIELD_DESC); + oprot.writeI16(struct.shortColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(INT_COLUMN_FIELD_DESC); + oprot.writeI32(struct.intColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(LONG_COLUMN_FIELD_DESC); + oprot.writeI64(struct.longColumn); + oprot.writeFieldEnd(); + oprot.writeFieldBegin(DOUBLE_COLUMN_FIELD_DESC); + oprot.writeDouble(struct.doubleColumn); + oprot.writeFieldEnd(); + if (struct.binaryColumn != null) { + oprot.writeFieldBegin(BINARY_COLUMN_FIELD_DESC); + oprot.writeBinary(struct.binaryColumn); + oprot.writeFieldEnd(); + } + if (struct.stringColumn != null) { + oprot.writeFieldBegin(STRING_COLUMN_FIELD_DESC); + oprot.writeString(struct.stringColumn); + oprot.writeFieldEnd(); + } + if (struct.enumColumn != null) { + oprot.writeFieldBegin(ENUM_COLUMN_FIELD_DESC); + oprot.writeI32(struct.enumColumn.getValue()); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeBoolColumn()) { + oprot.writeFieldBegin(MAYBE_BOOL_COLUMN_FIELD_DESC); + oprot.writeBool(struct.maybeBoolColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeByteColumn()) { + oprot.writeFieldBegin(MAYBE_BYTE_COLUMN_FIELD_DESC); + oprot.writeByte(struct.maybeByteColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeShortColumn()) { + oprot.writeFieldBegin(MAYBE_SHORT_COLUMN_FIELD_DESC); + oprot.writeI16(struct.maybeShortColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeIntColumn()) { + oprot.writeFieldBegin(MAYBE_INT_COLUMN_FIELD_DESC); + oprot.writeI32(struct.maybeIntColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeLongColumn()) { + oprot.writeFieldBegin(MAYBE_LONG_COLUMN_FIELD_DESC); + oprot.writeI64(struct.maybeLongColumn); + oprot.writeFieldEnd(); + } + if (struct.isSetMaybeDoubleColumn()) { + oprot.writeFieldBegin(MAYBE_DOUBLE_COLUMN_FIELD_DESC); + oprot.writeDouble(struct.maybeDoubleColumn); + oprot.writeFieldEnd(); + } + if (struct.maybeBinaryColumn != null) { + if (struct.isSetMaybeBinaryColumn()) { + oprot.writeFieldBegin(MAYBE_BINARY_COLUMN_FIELD_DESC); + oprot.writeBinary(struct.maybeBinaryColumn); + oprot.writeFieldEnd(); + } + } + if (struct.maybeStringColumn != null) { + if (struct.isSetMaybeStringColumn()) { + oprot.writeFieldBegin(MAYBE_STRING_COLUMN_FIELD_DESC); + oprot.writeString(struct.maybeStringColumn); + oprot.writeFieldEnd(); + } + } + if (struct.maybeEnumColumn != null) { + if (struct.isSetMaybeEnumColumn()) { + oprot.writeFieldBegin(MAYBE_ENUM_COLUMN_FIELD_DESC); + 
oprot.writeI32(struct.maybeEnumColumn.getValue()); + oprot.writeFieldEnd(); + } + } + if (struct.stringsColumn != null) { + oprot.writeFieldBegin(STRINGS_COLUMN_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, struct.stringsColumn.size())); + for (String _iter25 : struct.stringsColumn) + { + oprot.writeString(_iter25); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.intSetColumn != null) { + oprot.writeFieldBegin(INT_SET_COLUMN_FIELD_DESC); + { + oprot.writeSetBegin(new org.apache.thrift.protocol.TSet(org.apache.thrift.protocol.TType.I32, struct.intSetColumn.size())); + for (int _iter26 : struct.intSetColumn) + { + oprot.writeI32(_iter26); + } + oprot.writeSetEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.intToStringColumn != null) { + oprot.writeFieldBegin(INT_TO_STRING_COLUMN_FIELD_DESC); + { + oprot.writeMapBegin(new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.I32, org.apache.thrift.protocol.TType.STRING, struct.intToStringColumn.size())); + for (Map.Entry _iter27 : struct.intToStringColumn.entrySet()) + { + oprot.writeI32(_iter27.getKey()); + oprot.writeString(_iter27.getValue()); + } + oprot.writeMapEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.complexColumn != null) { + oprot.writeFieldBegin(COMPLEX_COLUMN_FIELD_DESC); + { + oprot.writeMapBegin(new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.I32, org.apache.thrift.protocol.TType.LIST, struct.complexColumn.size())); + for (Map.Entry> _iter28 : struct.complexColumn.entrySet()) + { + oprot.writeI32(_iter28.getKey()); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, _iter28.getValue().size())); + for (Nested _iter29 : _iter28.getValue()) + { + _iter29.write(oprot); + } + oprot.writeListEnd(); + } + } + oprot.writeMapEnd(); + } + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class ParquetThriftCompatTupleSchemeFactory implements SchemeFactory { + public ParquetThriftCompatTupleScheme getScheme() { + return new ParquetThriftCompatTupleScheme(); + } + } + + private static class ParquetThriftCompatTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, ParquetThriftCompat struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = (TTupleProtocol) prot; + oprot.writeBool(struct.boolColumn); + oprot.writeByte(struct.byteColumn); + oprot.writeI16(struct.shortColumn); + oprot.writeI32(struct.intColumn); + oprot.writeI64(struct.longColumn); + oprot.writeDouble(struct.doubleColumn); + oprot.writeBinary(struct.binaryColumn); + oprot.writeString(struct.stringColumn); + oprot.writeI32(struct.enumColumn.getValue()); + { + oprot.writeI32(struct.stringsColumn.size()); + for (String _iter30 : struct.stringsColumn) + { + oprot.writeString(_iter30); + } + } + { + oprot.writeI32(struct.intSetColumn.size()); + for (int _iter31 : struct.intSetColumn) + { + oprot.writeI32(_iter31); + } + } + { + oprot.writeI32(struct.intToStringColumn.size()); + for (Map.Entry _iter32 : struct.intToStringColumn.entrySet()) + { + oprot.writeI32(_iter32.getKey()); + oprot.writeString(_iter32.getValue()); + } + } + { + oprot.writeI32(struct.complexColumn.size()); + for (Map.Entry> _iter33 : struct.complexColumn.entrySet()) + { + oprot.writeI32(_iter33.getKey()); + { + oprot.writeI32(_iter33.getValue().size()); + for (Nested _iter34 : 
_iter33.getValue()) + { + _iter34.write(oprot); + } + } + } + } + BitSet optionals = new BitSet(); + if (struct.isSetMaybeBoolColumn()) { + optionals.set(0); + } + if (struct.isSetMaybeByteColumn()) { + optionals.set(1); + } + if (struct.isSetMaybeShortColumn()) { + optionals.set(2); + } + if (struct.isSetMaybeIntColumn()) { + optionals.set(3); + } + if (struct.isSetMaybeLongColumn()) { + optionals.set(4); + } + if (struct.isSetMaybeDoubleColumn()) { + optionals.set(5); + } + if (struct.isSetMaybeBinaryColumn()) { + optionals.set(6); + } + if (struct.isSetMaybeStringColumn()) { + optionals.set(7); + } + if (struct.isSetMaybeEnumColumn()) { + optionals.set(8); + } + oprot.writeBitSet(optionals, 9); + if (struct.isSetMaybeBoolColumn()) { + oprot.writeBool(struct.maybeBoolColumn); + } + if (struct.isSetMaybeByteColumn()) { + oprot.writeByte(struct.maybeByteColumn); + } + if (struct.isSetMaybeShortColumn()) { + oprot.writeI16(struct.maybeShortColumn); + } + if (struct.isSetMaybeIntColumn()) { + oprot.writeI32(struct.maybeIntColumn); + } + if (struct.isSetMaybeLongColumn()) { + oprot.writeI64(struct.maybeLongColumn); + } + if (struct.isSetMaybeDoubleColumn()) { + oprot.writeDouble(struct.maybeDoubleColumn); + } + if (struct.isSetMaybeBinaryColumn()) { + oprot.writeBinary(struct.maybeBinaryColumn); + } + if (struct.isSetMaybeStringColumn()) { + oprot.writeString(struct.maybeStringColumn); + } + if (struct.isSetMaybeEnumColumn()) { + oprot.writeI32(struct.maybeEnumColumn.getValue()); + } + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, ParquetThriftCompat struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + struct.boolColumn = iprot.readBool(); + struct.setBoolColumnIsSet(true); + struct.byteColumn = iprot.readByte(); + struct.setByteColumnIsSet(true); + struct.shortColumn = iprot.readI16(); + struct.setShortColumnIsSet(true); + struct.intColumn = iprot.readI32(); + struct.setIntColumnIsSet(true); + struct.longColumn = iprot.readI64(); + struct.setLongColumnIsSet(true); + struct.doubleColumn = iprot.readDouble(); + struct.setDoubleColumnIsSet(true); + struct.binaryColumn = iprot.readBinary(); + struct.setBinaryColumnIsSet(true); + struct.stringColumn = iprot.readString(); + struct.setStringColumnIsSet(true); + struct.enumColumn = org.apache.spark.sql.parquet.test.thrift.Suit.findByValue(iprot.readI32()); + struct.setEnumColumnIsSet(true); + { + org.apache.thrift.protocol.TList _list35 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.stringsColumn = new ArrayList(_list35.size); + String _elem36; + for (int _i37 = 0; _i37 < _list35.size; ++_i37) + { + _elem36 = iprot.readString(); + struct.stringsColumn.add(_elem36); + } + } + struct.setStringsColumnIsSet(true); + { + org.apache.thrift.protocol.TSet _set38 = new org.apache.thrift.protocol.TSet(org.apache.thrift.protocol.TType.I32, iprot.readI32()); + struct.intSetColumn = new HashSet(2*_set38.size); + int _elem39; + for (int _i40 = 0; _i40 < _set38.size; ++_i40) + { + _elem39 = iprot.readI32(); + struct.intSetColumn.add(_elem39); + } + } + struct.setIntSetColumnIsSet(true); + { + org.apache.thrift.protocol.TMap _map41 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.I32, org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.intToStringColumn = new HashMap(2*_map41.size); + int _key42; + String _val43; + for (int _i44 = 0; _i44 < _map41.size; ++_i44) + { + _key42 = iprot.readI32(); + 
_val43 = iprot.readString(); + struct.intToStringColumn.put(_key42, _val43); + } + } + struct.setIntToStringColumnIsSet(true); + { + org.apache.thrift.protocol.TMap _map45 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.I32, org.apache.thrift.protocol.TType.LIST, iprot.readI32()); + struct.complexColumn = new HashMap>(2*_map45.size); + int _key46; + List _val47; + for (int _i48 = 0; _i48 < _map45.size; ++_i48) + { + _key46 = iprot.readI32(); + { + org.apache.thrift.protocol.TList _list49 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, iprot.readI32()); + _val47 = new ArrayList(_list49.size); + Nested _elem50; + for (int _i51 = 0; _i51 < _list49.size; ++_i51) + { + _elem50 = new Nested(); + _elem50.read(iprot); + _val47.add(_elem50); + } + } + struct.complexColumn.put(_key46, _val47); + } + } + struct.setComplexColumnIsSet(true); + BitSet incoming = iprot.readBitSet(9); + if (incoming.get(0)) { + struct.maybeBoolColumn = iprot.readBool(); + struct.setMaybeBoolColumnIsSet(true); + } + if (incoming.get(1)) { + struct.maybeByteColumn = iprot.readByte(); + struct.setMaybeByteColumnIsSet(true); + } + if (incoming.get(2)) { + struct.maybeShortColumn = iprot.readI16(); + struct.setMaybeShortColumnIsSet(true); + } + if (incoming.get(3)) { + struct.maybeIntColumn = iprot.readI32(); + struct.setMaybeIntColumnIsSet(true); + } + if (incoming.get(4)) { + struct.maybeLongColumn = iprot.readI64(); + struct.setMaybeLongColumnIsSet(true); + } + if (incoming.get(5)) { + struct.maybeDoubleColumn = iprot.readDouble(); + struct.setMaybeDoubleColumnIsSet(true); + } + if (incoming.get(6)) { + struct.maybeBinaryColumn = iprot.readBinary(); + struct.setMaybeBinaryColumnIsSet(true); + } + if (incoming.get(7)) { + struct.maybeStringColumn = iprot.readString(); + struct.setMaybeStringColumnIsSet(true); + } + if (incoming.get(8)) { + struct.maybeEnumColumn = org.apache.spark.sql.parquet.test.thrift.Suit.findByValue(iprot.readI32()); + struct.setMaybeEnumColumnIsSet(true); + } + } + } + +} + diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Suit.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Suit.java new file mode 100644 index 0000000000000..5315c6aae9372 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Suit.java @@ -0,0 +1,51 @@ +/** + * Autogenerated by Thrift Compiler (0.9.2) + * + * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING + * @generated + */ +package org.apache.spark.sql.parquet.test.thrift; + + +import java.util.Map; +import java.util.HashMap; +import org.apache.thrift.TEnum; + +public enum Suit implements org.apache.thrift.TEnum { + SPADES(0), + HEARTS(1), + DIAMONDS(2), + CLUBS(3); + + private final int value; + + private Suit(int value) { + this.value = value; + } + + /** + * Get the integer value of this enum value, as defined in the Thrift IDL. + */ + public int getValue() { + return value; + } + + /** + * Find a the enum type by its integer value, as defined in the Thrift IDL. + * @return null if the value is not found. 
+ */ + public static Suit findByValue(int value) { + switch (value) { + case 0: + return SPADES; + case 1: + return HEARTS; + case 2: + return DIAMONDS; + case 3: + return CLUBS; + default: + return null; + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala new file mode 100644 index 0000000000000..bfa427349ff6a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.nio.ByteBuffer +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.avro.AvroParquetWriter + +import org.apache.spark.sql.parquet.test.avro.{Nested, ParquetAvroCompat} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{Row, SQLContext} + +class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { + import ParquetCompatibilityTest._ + + override val sqlContext: SQLContext = TestSQLContext + + override protected def beforeAll(): Unit = { + super.beforeAll() + + val writer = + new AvroParquetWriter[ParquetAvroCompat]( + new Path(parquetStore.getCanonicalPath), + ParquetAvroCompat.getClassSchema) + + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + writer.close() + } + + test("Read Parquet file generated by parquet-avro") { + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(parquetStore.getCanonicalPath)} + """.stripMargin) + + checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes, + s"val_$i", + + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i: Integer), + nullable(i.toLong: java.lang.Long), + nullable(i.toFloat + 0.1f: java.lang.Float), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(s"val_$i".getBytes), + nullable(s"val_$i"), + + Seq.tabulate(3)(n => s"arr_${i + n}"), + Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } + + def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { + def nullable[T <: AnyRef] = makeNullable[T](i) _ + + def makeComplexColumn(i: Int): JMap[String, JList[Nested]] = { + mapAsJavaMap(Seq.tabulate(3) { n => + (i + n).toString -> seqAsJavaList(Seq.tabulate(3) 
{ m => + Nested + .newBuilder() + .setNestedIntsColumn(seqAsJavaList(Seq.tabulate(3)(j => i + j + m))) + .setNestedStringColumn(s"val_${i + m}") + .build() + }) + }.toMap) + } + + ParquetAvroCompat + .newBuilder() + .setBoolColumn(i % 2 == 0) + .setIntColumn(i) + .setLongColumn(i.toLong * 10) + .setFloatColumn(i.toFloat + 0.1f) + .setDoubleColumn(i.toDouble + 0.2d) + .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes)) + .setStringColumn(s"val_$i") + + .setMaybeBoolColumn(nullable(i % 2 == 0: java.lang.Boolean)) + .setMaybeIntColumn(nullable(i: Integer)) + .setMaybeLongColumn(nullable(i.toLong: java.lang.Long)) + .setMaybeFloatColumn(nullable(i.toFloat + 0.1f: java.lang.Float)) + .setMaybeDoubleColumn(nullable(i.toDouble + 0.2d: java.lang.Double)) + .setMaybeBinaryColumn(nullable(ByteBuffer.wrap(s"val_$i".getBytes))) + .setMaybeStringColumn(nullable(s"val_$i")) + + .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}")) + .setStringToIntColumn( + mapAsJavaMap(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap)) + .setComplexColumn(makeComplexColumn(i)) + + .build() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala new file mode 100644 index 0000000000000..b4cdfd9e98f6f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet +import java.io.File + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.schema.MessageType +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.QueryTest +import org.apache.spark.util.Utils + +abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { + protected var parquetStore: File = _ + + override protected def beforeAll(): Unit = { + parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") + parquetStore.delete() + } + + override protected def afterAll(): Unit = { + Utils.deleteRecursively(parquetStore) + } + + def readParquetSchema(path: String): MessageType = { + val fsPath = new Path(path) + val fs = fsPath.getFileSystem(configuration) + val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot(_.getPath.getName.startsWith("_")) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) + footers.head.getParquetMetadata.getFileMetaData.getSchema + } +} + +object ParquetCompatibilityTest { + def makeNullable[T <: AnyRef](i: Int)(f: => T): T = { + if (i % 3 == 0) null.asInstanceOf[T] else f + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala new file mode 100644 index 0000000000000..d22066cabc567 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet + +import java.nio.ByteBuffer +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.thrift.ThriftParquetWriter + +import org.apache.spark.sql.parquet.test.thrift.{Nested, ParquetThriftCompat, Suit} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{Row, SQLContext} + +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest { + import ParquetCompatibilityTest._ + + override val sqlContext: SQLContext = TestSQLContext + + override protected def beforeAll(): Unit = { + super.beforeAll() + + val writer = + new ThriftParquetWriter[ParquetThriftCompat]( + new Path(parquetStore.getCanonicalPath), + classOf[ParquetThriftCompat], + CompressionCodecName.SNAPPY) + + (0 until 10).foreach(i => writer.write(makeParquetThriftCompat(i))) + writer.close() + } + + test("Read Parquet file generated by parquet-thrift") { + logInfo( + s"""Schema of the Parquet file written by parquet-thrift: + |${readParquetSchema(parquetStore.getCanonicalPath)} + """.stripMargin) + + checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + Row( + i % 2 == 0, + i.toByte, + (i + 1).toShort, + i + 2, + i.toLong * 10, + i.toDouble + 0.2d, + // Thrift `BINARY` values are actually unencoded `STRING` values, and thus are always + // treated as `BINARY (UTF8)` in parquet-thrift, since parquet-thrift always assume + // Thrift `STRING`s are encoded using UTF-8. + s"val_$i", + s"val_$i", + // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings + Suit.values()(i % 4).name(), + + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i.toByte: java.lang.Byte), + nullable((i + 1).toShort: java.lang.Short), + nullable(i + 2: Integer), + nullable((i * 10).toLong: java.lang.Long), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(s"val_$i"), + nullable(s"val_$i"), + nullable(Suit.values()(i % 4).name()), + + Seq.tabulate(3)(n => s"arr_${i + n}"), + // Thrift `SET`s are converted to Parquet `LIST`s + Seq(i), + Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap, + Seq.tabulate(3) { n => + (i + n) -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } + + def makeParquetThriftCompat(i: Int): ParquetThriftCompat = { + def makeComplexColumn(i: Int): JMap[Integer, JList[Nested]] = { + mapAsJavaMap(Seq.tabulate(3) { n => + (i + n: Integer) -> seqAsJavaList(Seq.tabulate(3) { m => + new Nested( + seqAsJavaList(Seq.tabulate(3)(j => i + j + m)), + s"val_${i + m}") + }) + }.toMap) + } + + val value = + new ParquetThriftCompat( + i % 2 == 0, + i.toByte, + (i + 1).toShort, + i + 2, + i.toLong * 10, + i.toDouble + 0.2d, + ByteBuffer.wrap(s"val_$i".getBytes), + s"val_$i", + Suit.values()(i % 4), + + seqAsJavaList(Seq.tabulate(3)(n => s"arr_${i + n}")), + setAsJavaSet(Set(i)), + mapAsJavaMap(Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap), + makeComplexColumn(i)) + + if (i % 3 == 0) { + value + } else { + value + .setMaybeBoolColumn(i % 2 == 0) + .setMaybeByteColumn(i.toByte) + .setMaybeShortColumn((i + 1).toShort) + .setMaybeIntColumn(i + 2) + .setMaybeLongColumn(i.toLong * 10) + .setMaybeDoubleColumn(i.toDouble + 0.2d) + .setMaybeBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes)) + 
.setMaybeStringColumn(s"val_$i") + .setMaybeEnumColumn(Suit.values()(i % 4)) + } + } +} diff --git a/sql/core/src/test/scripts/gen-code.sh b/sql/core/src/test/scripts/gen-code.sh new file mode 100755 index 0000000000000..5d8d8ad08555c --- /dev/null +++ b/sql/core/src/test/scripts/gen-code.sh @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. +BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +thrift\ + --gen java\ + -out $BASEDIR/gen-java\ + $BASEDIR/thrift/parquet-compat.thrift + +avro-tools idl $BASEDIR/avro/parquet-compat.avdl > $BASEDIR/avro/parquet-compat.avpr +avro-tools compile -string protocol $BASEDIR/avro/parquet-compat.avpr $BASEDIR/gen-java diff --git a/sql/core/src/test/thrift/parquet-compat.thrift b/sql/core/src/test/thrift/parquet-compat.thrift new file mode 100644 index 0000000000000..fa5ed8c62306a --- /dev/null +++ b/sql/core/src/test/thrift/parquet-compat.thrift @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace java org.apache.spark.sql.parquet.test.thrift + +enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS +} + +struct Nested { + 1: required list<i32> nestedIntsColumn; + 2: required string nestedStringColumn; +} + +/** + * This is a test struct for testing parquet-thrift compatibility.
+ */ +struct ParquetThriftCompat { + 1: required bool boolColumn; + 2: required byte byteColumn; + 3: required i16 shortColumn; + 4: required i32 intColumn; + 5: required i64 longColumn; + 6: required double doubleColumn; + 7: required binary binaryColumn; + 8: required string stringColumn; + 9: required Suit enumColumn + + 10: optional bool maybeBoolColumn; + 11: optional byte maybeByteColumn; + 12: optional i16 maybeShortColumn; + 13: optional i32 maybeIntColumn; + 14: optional i64 maybeLongColumn; + 15: optional double maybeDoubleColumn; + 16: optional binary maybeBinaryColumn; + 17: optional string maybeStringColumn; + 18: optional Suit maybeEnumColumn; + + 19: required list<string> stringsColumn; + 20: required set<i32> intSetColumn; + 21: required map<i32, string> intToStringColumn; + 22: required map<i32, list<Nested>> complexColumn; +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala new file mode 100644 index 0000000000000..bb5f1febe9ad4 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.{Row, SQLConf, SQLContext} + +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { + import ParquetCompatibilityTest.makeNullable + + override val sqlContext: SQLContext = TestHive + + override protected def beforeAll(): Unit = { + super.beforeAll() + + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + withTempTable("data") { + sqlContext.sql( + s"""CREATE TABLE parquet_compat( + | bool_column BOOLEAN, + | byte_column TINYINT, + | short_column SMALLINT, + | int_column INT, + | long_column BIGINT, + | float_column FLOAT, + | double_column DOUBLE, + | + | strings_column ARRAY<STRING>, + | int_to_string_column MAP<INT, STRING> + |) + |STORED AS PARQUET + |LOCATION '${parquetStore.getCanonicalPath}' + """.stripMargin) + + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + } + } + } + + override protected def afterAll(): Unit = { + sqlContext.sql("DROP TABLE parquet_compat") + } + + test("Read Parquet file generated by parquet-hive") { + logInfo( + s"""Schema of the Parquet file written by parquet-hive: + |${readParquetSchema(parquetStore.getCanonicalPath)} + """.stripMargin) + + // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. + // Have to assume all BINARY values are strings here. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), makeRows) + } + } + + def makeRows: Seq[Row] = { + (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + Row( + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i.toByte: java.lang.Byte), + nullable((i + 1).toShort: java.lang.Short), + nullable(i + 2: Integer), + nullable(i.toLong * 10: java.lang.Long), + nullable(i.toFloat + 0.1f: java.lang.Float), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(Seq.tabulate(3)(n => s"arr_${i + n}")), + nullable(Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c2e09800933b5..9d79a4b007d66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,14 +21,16 @@ import java.io.File import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql._ import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, QueryTest, Row, SQLConf, SaveMode} import org.apache.spark.util.Utils // The data where the partitioning key exists only in the
directory structure. @@ -685,6 +687,31 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest { sql("drop table spark_6016_fix") } + + test("SPARK-8811: compatibility with array of struct in Hive") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("array_of_struct") { + val conf = Seq( + HiveContext.CONVERT_METASTORE_PARQUET.key -> "false", + SQLConf.PARQUET_BINARY_AS_STRING.key -> "true", + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "true") + + withSQLConf(conf: _*) { + sql( + s"""CREATE TABLE array_of_struct + |STORED AS PARQUET LOCATION '$path' + |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) + """.stripMargin) + + checkAnswer( + sqlContext.read.parquet(path), + Row("1st", "2nd", Seq(Row("val_a", "val_b")))) + } + } + } + } } class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { @@ -762,7 +789,9 @@ class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { /** * A collection of tests for parquet data with various forms of partitioning. */ -abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + override def sqlContext: SQLContext = TestHive + var partitionedTableDir: File = null var normalTableDir: File = null var partitionedTableDirWithKey: File = null From 381cb161ba4e3a30f2da3c4ef4ee19869d51f101 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 8 Jul 2015 16:21:28 -0700 Subject: [PATCH 059/149] [SPARK-8068] [MLLIB] Add confusionMatrix method at class MulticlassMetrics in pyspark/mllib Add confusionMatrix method at class MulticlassMetrics in pyspark/mllib Author: Yanbo Liang Closes #7286 from yanboliang/spark-8068 and squashes the following commits: 6109fe1 [Yanbo Liang] Add confusionMatrix method at class MulticlassMetrics in pyspark/mllib --- python/pyspark/mllib/evaluation.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index c5cf3a4e7ff22..f21403707e12a 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -152,6 +152,10 @@ class MulticlassMetrics(JavaModelWrapper): >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) >>> metrics = MulticlassMetrics(predictionAndLabels) + >>> metrics.confusionMatrix().toArray() + array([[ 2., 1., 1.], + [ 1., 3., 0.], + [ 0., 0., 1.]]) >>> metrics.falsePositiveRate(0.0) 0.2... >>> metrics.precision(1.0) @@ -186,6 +190,13 @@ def __init__(self, predictionAndLabels): java_model = java_class(df._jdf) super(MulticlassMetrics, self).__init__(java_model) + def confusionMatrix(self): + """ + Returns confusion matrix: predicted classes are in columns, + they are ordered by class label ascending, as in "labels". + """ + return self.call("confusionMatrix") + def truePositiveRate(self, label): """ Returns true positive rate for a given label (category). From 8c32b2e870c7c250a63e838718df833edf6dea07 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 8 Jul 2015 16:27:11 -0700 Subject: [PATCH 060/149] [SPARK-8877] [MLLIB] Public API for association rule generation Adds FPGrowth.generateAssociationRules to public API for generating association rules after mining frequent itemsets. 
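
For illustration only (this sketch is not part of the patch): given the `FPGrowth.run` and `FPGrowthModel.generateAssociationRules` signatures added below, and mirroring the usage exercised in the new FPGrowthSuite test, end-to-end usage would look roughly as follows. The SparkContext `sc`, the sample transactions, and the 0.5/0.8 thresholds are assumptions made for the example.

    import org.apache.spark.mllib.fpm.FPGrowth
    import org.apache.spark.rdd.RDD

    // Hypothetical input: each transaction is an array of item strings.
    val transactions: RDD[Array[String]] =
      sc.parallelize(Seq("r z h k p", "z y x w v u t s", "s x o n r").map(_.split(" ")))

    // Mine frequent itemsets, then derive rules with confidence >= 0.8.
    val model = new FPGrowth().setMinSupport(0.5).setNumPartitions(2).run(transactions)
    val rules = model.generateAssociationRules(0.8)
    rules.collect().foreach { rule =>
      println(rule.antecedent.mkString("[", ",", "]") + " => " +
        rule.consequent.mkString("[", ",", "]") + ", confidence = " + rule.confidence)
    }
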
Author: Feynman Liang Closes #7271 from feynmanliang/SPARK-8877 and squashes the following commits: 83b8baf [Feynman Liang] Add API Doc 867abff [Feynman Liang] Add FPGrowth.generateAssociationRules and change access modifiers for AssociationRules --- .../spark/mllib/fpm/AssociationRules.scala | 5 ++- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 11 ++++- .../spark/mllib/fpm/FPGrowthSuite.scala | 42 +++++++++++++++++++ 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 4a0f842f3338d..7e2bbfe31c1b7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD * association rules which have a single item as the consequent. */ @Experimental -class AssociationRules private ( +class AssociationRules private[fpm] ( private var minConfidence: Double) extends Logging with Serializable { /** @@ -45,6 +45,7 @@ class AssociationRules private ( * Sets the minimal confidence (default: `0.8`). */ def setMinConfidence(minConfidence: Double): this.type = { + require(minConfidence >= 0.0 && minConfidence <= 1.0) this.minConfidence = minConfidence this } @@ -91,7 +92,7 @@ object AssociationRules { * @tparam Item item type */ @Experimental - class Rule[Item] private[mllib] ( + class Rule[Item] private[fpm] ( val antecedent: Array[Item], val consequent: Array[Item], freqUnion: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 0da59e812d5f9..9cb9a00dbd9c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -40,7 +40,16 @@ import org.apache.spark.storage.StorageLevel * @tparam Item item type */ @Experimental -class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable +class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { + /** + * Generates association rules for the [[Item]]s in [[freqItemsets]]. 
+ * @param confidence minimal confidence of the rules produced + */ + def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { + val associationRules = new AssociationRules(confidence) + associationRules.run(freqItemsets) + } +} /** * :: Experimental :: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index ddc296a428907..4a9bfdb348d9f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -132,6 +132,48 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model1.freqItemsets.count() === 625) } + test("FP-Growth String type association rule generation") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + /* Verify results using the `R` code: + transactions = as(sapply( + list("r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p"), + FUN=function(x) strsplit(x," ",fixed=TRUE)), + "transactions") + ars = apriori(transactions, + parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + > nrow(arsDF) + [1] 23 + > sum(arsDF$confidence == 1) + [1] 23 + */ + val rules = (new FPGrowth()) + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + .generateAssociationRules(0.9) + .collect() + + assert(rules.size === 23) + assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + } + test("FP-Growth using Int type") { val transactions = Seq( "1 2 3", From f472b8cdc00839780dc79be0bbe53a098cde230c Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 8 Jul 2015 16:32:00 -0700 Subject: [PATCH 061/149] [SPARK-5016] [MLLIB] Distribute GMM mixture components to executors Distribute expensive portions of computation for Gaussian mixture components (in particular, pre-computation of `MultivariateGaussian.rootSigmaInv`, the inverse covariance matrix and covariance determinant) across executors. Repost of PR#4654. Notes for reviewers: * What should be the policy for when to distribute computation. Always? When numClusters > threshold? User-specified param? 
TODO: * Performance testing and comparison for large number of clusters Author: Feynman Liang Closes #7166 from feynmanliang/GMM_parallel_mixtures and squashes the following commits: 4f351fa [Feynman Liang] Update heuristic and scaladoc 5ea947e [Feynman Liang] Fix parallelization logic 00eb7db [Feynman Liang] Add helper method for GMM's M step, remove distributeGaussians flag e7c8127 [Feynman Liang] Add distributeGaussians flag and tests 1da3c7f [Feynman Liang] Distribute mixtures --- .../mllib/clustering/GaussianMixture.scala | 44 +++++++++++++++---- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index fc509d2ba1470..e459367333d26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -140,6 +140,10 @@ class GaussianMixture private ( // Get length of the input vectors val d = breezeData.first().length + // Heuristic to distribute the computation of the [[MultivariateGaussian]]s, approximately when + // d > 25 except for when k is very small + val distributeGaussians = ((k - 1.0) / k) * d > 25 + // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise // we start with uniform weights, a random mean from the data, and @@ -171,14 +175,25 @@ class GaussianMixture private ( // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum - var i = 0 - while (i < k) { - val mu = sums.means(i) / sums.weights(i) - BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu), - Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) - weights(i) = sums.weights(i) / sumWeights - gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) - i = i + 1 + + if (distributeGaussians) { + val numPartitions = math.min(k, 1024) + val tuples = + Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i))) + val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) => + updateWeightsAndGaussians(mean, sigma, weight, sumWeights) + }.collect.unzip + Array.copy(ws, 0, weights, 0, ws.length) + Array.copy(gs, 0, gaussians, 0, gs.length) + } else { + var i = 0 + while (i < k) { + val (weight, gaussian) = + updateWeightsAndGaussians(sums.means(i), sums.sigmas(i), sums.weights(i), sumWeights) + weights(i) = weight + gaussians(i) = gaussian + i = i + 1 + } } llhp = llh // current becomes previous @@ -192,6 +207,19 @@ class GaussianMixture private ( /** Java-friendly version of [[run()]] */ def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) + private def updateWeightsAndGaussians( + mean: BDV[Double], + sigma: BreezeMatrix[Double], + weight: Double, + sumWeights: Double): (Double, MultivariateGaussian) = { + val mu = (mean /= weight) + BLAS.syr(-weight, Vectors.fromBreeze(mu), + Matrices.fromBreeze(sigma).asInstanceOf[DenseMatrix]) + val newWeight = weight / sumWeights + val newGaussian = new MultivariateGaussian(mu, sigma / weight) + (newWeight, newGaussian) + } + /** Average of dense breeze vectors */ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) From 2a4f88b6c16f2991e63b17c0e103bcd79f04dbbc Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 8 
Jul 2015 18:09:39 -0700 Subject: [PATCH 062/149] [SPARK-8914][SQL] Remove RDDApi As rxin suggested in #7298 , we should consider to remove `RDDApi`. Author: Kousuke Saruta Closes #7302 from sarutak/remove-rddapi and squashes the following commits: e495d35 [Kousuke Saruta] Fixed mima cb7ebb9 [Kousuke Saruta] Removed overriding RDDApi --- project/MimaExcludes.scala | 5 ++ .../org/apache/spark/sql/DataFrame.scala | 39 ++++++----- .../scala/org/apache/spark/sql/RDDApi.scala | 67 ------------------- 3 files changed, 24 insertions(+), 87 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7346d804632bc..57a86bf8deb64 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,7 +70,12 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrix.numActives") + ) ++ Seq( + // SPARK-8914 Remove RDDApi + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.RDDApi") ) + case v if v.startsWith("1.4") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f33e19a0cb7dd..eeefc85255d14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -115,8 +115,7 @@ private[sql] object DataFrame { @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) - extends RDDApi[Row] with Serializable { + @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) extends Serializable { /** * A constructor that automatically analyzes the logical plan. @@ -1320,14 +1319,14 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - override def first(): Row = head() + def first(): Row = head() /** * Returns a new RDD by applying a function to all rows of this DataFrame. * @group rdd * @since 1.3.0 */ - override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) + def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) /** * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], @@ -1335,14 +1334,14 @@ class DataFrame private[sql]( * @group rdd * @since 1.3.0 */ - override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) + def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) /** * Returns a new RDD by applying a function to each partition of this DataFrame. * @group rdd * @since 1.3.0 */ - override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { rdd.mapPartitions(f) } @@ -1351,49 +1350,49 @@ class DataFrame private[sql]( * @group rdd * @since 1.3.0 */ - override def foreach(f: Row => Unit): Unit = rdd.foreach(f) + def foreach(f: Row => Unit): Unit = rdd.foreach(f) /** * Applies a function f to each partition of this [[DataFrame]]. * @group rdd * @since 1.3.0 */ - override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) + def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) /** * Returns the first `n` rows in the [[DataFrame]]. 
* @group action * @since 1.3.0 */ - override def take(n: Int): Array[Row] = head(n) + def take(n: Int): Array[Row] = head(n) /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. * @group action * @since 1.3.0 */ - override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() + def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() /** * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. * @group action * @since 1.3.0 */ - override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*) + def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*) /** * Returns the number of rows in the [[DataFrame]]. * @group action * @since 1.3.0 */ - override def count(): Long = groupBy().count().collect().head.getLong(0) + def count(): Long = groupBy().count().collect().head.getLong(0) /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. * @group rdd * @since 1.3.0 */ - override def repartition(numPartitions: Int): DataFrame = { + def repartition(numPartitions: Int): DataFrame = { Repartition(numPartitions, shuffle = true, logicalPlan) } @@ -1405,7 +1404,7 @@ class DataFrame private[sql]( * @group rdd * @since 1.4.0 */ - override def coalesce(numPartitions: Int): DataFrame = { + def coalesce(numPartitions: Int): DataFrame = { Repartition(numPartitions, shuffle = false, logicalPlan) } @@ -1415,13 +1414,13 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - override def distinct(): DataFrame = dropDuplicates() + def distinct(): DataFrame = dropDuplicates() /** * @group basic * @since 1.3.0 */ - override def persist(): this.type = { + def persist(): this.type = { sqlContext.cacheManager.cacheQuery(this) this } @@ -1430,13 +1429,13 @@ class DataFrame private[sql]( * @group basic * @since 1.3.0 */ - override def cache(): this.type = persist() + def cache(): this.type = persist() /** * @group basic * @since 1.3.0 */ - override def persist(newLevel: StorageLevel): this.type = { + def persist(newLevel: StorageLevel): this.type = { sqlContext.cacheManager.cacheQuery(this, None, newLevel) this } @@ -1445,7 +1444,7 @@ class DataFrame private[sql]( * @group basic * @since 1.3.0 */ - override def unpersist(blocking: Boolean): this.type = { + def unpersist(blocking: Boolean): this.type = { sqlContext.cacheManager.tryUncacheQuery(this, blocking) this } @@ -1454,7 +1453,7 @@ class DataFrame private[sql]( * @group basic * @since 1.3.0 */ - override def unpersist(): this.type = unpersist(blocking = false) + def unpersist(): this.type = unpersist(blocking = false) ///////////////////////////////////////////////////////////////////////////// // I/O diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala deleted file mode 100644 index 63dbab19947c0..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. 
You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import scala.reflect.ClassTag - -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - - -/** - * An internal interface defining the RDD-like methods for [[DataFrame]]. - * Please use [[DataFrame]] directly, and do NOT use this. - */ -private[sql] trait RDDApi[T] { - - def cache(): this.type - - def persist(): this.type - - def persist(newLevel: StorageLevel): this.type - - def unpersist(): this.type - - def unpersist(blocking: Boolean): this.type - - def map[R: ClassTag](f: T => R): RDD[R] - - def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R] - - def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] - - def foreach(f: T => Unit): Unit - - def foreachPartition(f: Iterator[T] => Unit): Unit - - def take(n: Int): Array[T] - - def collect(): Array[T] - - def collectAsList(): java.util.List[T] - - def count(): Long - - def first(): T - - def repartition(numPartitions: Int): DataFrame - - def coalesce(numPartitions: Int): DataFrame - - def distinct: DataFrame -} From 74d8d3d928cc9a7386b68588ac89ae042847d146 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 8 Jul 2015 18:22:53 -0700 Subject: [PATCH 063/149] [SPARK-8450] [SQL] [PYSARK] cleanup type converter for Python DataFrame This PR fixes the converter for Python DataFrame, especially for DecimalType Closes #7106 Author: Davies Liu Closes #7131 from davies/decimal_python and squashes the following commits: 4d3c234 [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python 20531d6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python 7d73168 [Davies Liu] fix conflit 6cdd86a [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python 7104e97 [Davies Liu] improve type infer 9cd5a21 [Davies Liu] run python tests with SPARK_PREPEND_CLASSES 829a05b [Davies Liu] fix UDT in python c99e8c5 [Davies Liu] fix mima c46814a [Davies Liu] convert decimal for Python DataFrames --- .../apache/spark/mllib/linalg/Matrices.scala | 10 +- .../apache/spark/mllib/linalg/Vectors.scala | 16 +--- project/MimaExcludes.scala | 5 +- python/pyspark/sql/tests.py | 13 +++ python/pyspark/sql/types.py | 4 + python/run-tests.py | 3 +- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 28 +----- .../spark/sql/execution/pythonUDFs.scala | 95 ++++++++++--------- 9 files changed, 84 insertions(+), 94 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 75e7004464af9..0df07663405a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.Row -import org.apache.spark.sql.types._ import 
org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ /** * Trait for a local matrix. @@ -147,7 +147,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { )) } - override def serialize(obj: Any): Row = { + override def serialize(obj: Any): InternalRow = { val row = new GenericMutableRow(7) obj match { case sm: SparseMatrix => @@ -173,9 +173,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def deserialize(datum: Any): Matrix = { datum match { - // TODO: something wrong with UDT serialization, should never happen. - case m: Matrix => m - case row: Row => + case row: InternalRow => require(row.length == 7, s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7") val tpe = row.getByte(0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index c9c27425d2877..e048b01d92462 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -28,7 +28,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types._ @@ -175,7 +175,7 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) } - override def serialize(obj: Any): Row = { + override def serialize(obj: Any): InternalRow = { obj match { case SparseVector(size, indices, values) => val row = new GenericMutableRow(4) @@ -191,17 +191,12 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { row.setNullAt(2) row.update(3, values.toSeq) row - // TODO: There are bugs in UDT serialization because we don't have a clear separation between - // TODO: internal SQL types and language specific types (including UDT). UDT serialize and - // TODO: deserialize may get called twice. See SPARK-7186. - case row: Row => - row } } override def deserialize(datum: Any): Vector = { datum match { - case row: Row => + case row: InternalRow => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") val tpe = row.getByte(0) @@ -215,11 +210,6 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val values = row.getAs[Iterable[Double]](3).toArray new DenseVector(values) } - // TODO: There are bugs in UDT serialization because we don't have a clear separation between - // TODO: internal SQL types and language specific types (including UDT). UDT serialize and - // TODO: deserialize may get called twice. See SPARK-7186. - case v: Vector => - v } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 57a86bf8deb64..821aadd477ef3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -63,7 +63,10 @@ object MimaExcludes { // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), // Parquet support is considered private. 
- excludePackage("org.apache.spark.sql.parquet") + excludePackage("org.apache.spark.sql.parquet"), + // local function inside a method + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1") ) ++ Seq( // SPARK-8479 Add numNonzeros and numActives to Matrix. ProblemFilters.exclude[MissingMethodProblem]( diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 333378c7f1854..66827d48850d9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -700,6 +700,19 @@ def test_time_with_timezone(self): self.assertTrue(now - now1 < datetime.timedelta(0.001)) self.assertTrue(now - utcnow1 < datetime.timedelta(0.001)) + def test_decimal(self): + from decimal import Decimal + schema = StructType([StructField("decimal", DecimalType(10, 5))]) + df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema) + row = df.select(df.decimal + 1).first() + self.assertEqual(row[0], Decimal("4.14159")) + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.parquet(tmpPath) + df2 = self.sqlCtx.read.parquet(tmpPath) + row = df2.first() + self.assertEqual(row[0], Decimal("3.14159")) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 160df40d65cc1..7e64cb0b54dba 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1069,6 +1069,10 @@ def _verify_type(obj, dataType): if obj is None: return + # StringType can work with any types + if isinstance(dataType, StringType): + return + if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): raise ValueError("%r is not an instance of type %r" % (obj, dataType)) diff --git a/python/run-tests.py b/python/run-tests.py index 7638854def2e8..cc560779373b3 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -72,7 +72,8 @@ def print_red(text): def run_individual_python_test(test_name, pyspark_python): - env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} + env = dict(os.environ) + env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}) LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index eeefc85255d14..d9f987ae0252f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1549,8 +1549,8 @@ class DataFrame private[sql]( * Converts a JavaRDD to a PythonRDD. 
*/ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + val structType = schema // capture it for closure + val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)).toJavaRDD() SerDeUtil.javaToPython(jrdd) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 079f31ab8fe6d..477dea9164726 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1044,33 +1044,7 @@ class SQLContext(@transient val sparkContext: SparkContext) rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - def needsConversion(dataType: DataType): Boolean = dataType match { - case ByteType => true - case ShortType => true - case LongType => true - case FloatType => true - case DateType => true - case TimestampType => true - case StringType => true - case ArrayType(_, _) => true - case MapType(_, _, _) => true - case StructType(_) => true - case udt: UserDefinedType[_] => needsConversion(udt.sqlType) - case other => false - } - - val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { - rdd.map(m => m.zip(schema.fields).map { - case (value, field) => EvaluatePython.fromJava(value, field.dataType) - }) - } else { - rdd - } - - val rowRdd = convertedRdd.mapPartitions { iter => - iter.map { m => new GenericInternalRow(m): InternalRow} - } - + val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 6946e798b71b0..1c8130b07c7fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -24,20 +24,19 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} -import org.apache.spark.{Accumulator, Logging => SparkLogging} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Accumulator, Logging => SparkLogging} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -125,59 +124,86 @@ object EvaluatePython { new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) /** - * Helper for converting a Scala object to a java suitable for pyspark serialization. + * Helper for converting from Catalyst type to java type suitable for Pyrolite. 
*/ def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (row: Row, struct: StructType) => + case (row: InternalRow, struct: StructType) => val fields = struct.fields.map(field => field.dataType) - row.toSeq.zip(fields).map { - case (obj, dataType) => toJava(obj, dataType) - }.toArray + rowToArray(row, fields) case (seq: Seq[Any], array: ArrayType) => seq.map(x => toJava(x, array.elementType)).asJava - case (list: JList[_], array: ArrayType) => - list.map(x => toJava(x, array.elementType)).asJava - case (arr, array: ArrayType) if arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) case (obj: Map[_, _], mt: MapType) => obj.map { case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType)) }.asJava - case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) + + case (d: Decimal, _) => d.toJavaBigDecimal + case (s: UTF8String, StringType) => s.toString - // Pyrolite can handle Timestamp and Decimal case (other, _) => other } /** * Convert Row into Java Array (for pickled into Python) */ - def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { + def rowToArray(row: InternalRow, fields: Seq[DataType]): Array[Any] = { // TODO: this is slow! row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray } - // Converts value to the type specified by the data type. - // Because Python does not have data types for TimestampType, FloatType, ShortType, and - // ByteType, we need to explicitly convert values in columns of these data types to the desired - // JVM data types. + /** + * Converts `obj` to the type specified by the data type, or returns null if the type of obj is + * unexpected. Because Python doesn't enforce the type. + */ def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - // TODO: We should check nullable case (null, _) => null + case (c: Boolean, BooleanType) => c + + case (c: Int, ByteType) => c.toByte + case (c: Long, ByteType) => c.toByte + + case (c: Int, ShortType) => c.toShort + case (c: Long, ShortType) => c.toShort + + case (c: Int, IntegerType) => c + case (c: Long, IntegerType) => c.toInt + + case (c: Int, LongType) => c.toLong + case (c: Long, LongType) => c + + case (c: Double, FloatType) => c.toFloat + + case (c: Double, DoubleType) => c + + case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c) + + case (c: Int, DateType) => c + + case (c: Long, TimestampType) => c + + case (c: String, StringType) => UTF8String.fromString(c) + case (c, StringType) => + // If we get here, c is not a string. Call toString on it. 
+ UTF8String.fromString(c.toString) + + case (c: String, BinaryType) => c.getBytes("utf-8") + case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + case (c: java.util.List[_], ArrayType(elementType, _)) => - c.map { e => fromJava(e, elementType)}: Seq[Any] + c.map { e => fromJava(e, elementType)}.toSeq case (c, ArrayType(elementType, _)) if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any] + c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { case (key, value) => (fromJava(key, keyType), fromJava(value, valueType)) @@ -188,30 +214,11 @@ object EvaluatePython { case (e, f) => fromJava(e, f.dataType) }) - case (c: java.util.Calendar, DateType) => - DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) - - case (c: java.util.Calendar, TimestampType) => - c.getTimeInMillis * 10000L - case (t: java.sql.Timestamp, TimestampType) => - DateTimeUtils.fromJavaTimestamp(t) - - case (_, udt: UserDefinedType[_]) => - fromJava(obj, udt.sqlType) - - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort - case (c: Long, IntegerType) => c.toInt - case (c: Int, LongType) => c.toLong - case (c: Double, FloatType) => c.toFloat - case (c: String, StringType) => UTF8String.fromString(c) - case (c, StringType) => - // If we get here, c is not a string. Call toString on it. - UTF8String.fromString(c.toString) + case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) - case (c, _) => c + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + case (c, _) => null } } From 28fa01e2ba146e823489f6d81c5eb3a76b20c71f Mon Sep 17 00:00:00 2001 From: Jonathan Alter Date: Thu, 9 Jul 2015 03:28:51 +0100 Subject: [PATCH 064/149] [SPARK-8927] [DOCS] Format wrong for some config descriptions A couple descriptions were not inside `` and were being displayed immediately under the section title instead of in their row. Author: Jonathan Alter Closes #7292 from jonalter/docs-config and squashes the following commits: 5ce1570 [Jonathan Alter] [DOCS] Format wrong for some config descriptions --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index bebaf6f62e90a..892c02b27df32 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1007,9 +1007,9 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.numRetries 3 + Number of times to retry before an RPC task gives up. An RPC task will run at most times of this number. - @@ -1029,8 +1029,8 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.lookupTimeout 120s - Duration for an RPC remote endpoint lookup operation to wait before timing out. + Duration for an RPC remote endpoint lookup operation to wait before timing out. 
From a290814877308c6fa9b0f78b1a81145db7651ca4 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 8 Jul 2015 20:20:17 -0700 Subject: [PATCH 065/149] [SPARK-8866][SQL] use 1us precision for timestamp type JIRA: https://issues.apache.org/jira/browse/SPARK-8866 Author: Yijie Shen Closes #7283 from yijieshen/micro_timestamp and squashes the following commits: dc735df [Yijie Shen] update CastSuite to avoid round error 714eaea [Yijie Shen] add timestamp_udf into blacklist due to precision lose c3ca2f4 [Yijie Shen] fix unhandled case in CurrentTimestamp 8d4aa6b [Yijie Shen] use 1us precision for timestamp type --- python/pyspark/sql/types.py | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 18 ++++----- .../expressions/datetimeFunctions.scala | 2 +- .../sql/catalyst/util/DateTimeUtils.scala | 38 +++++++++---------- .../sql/catalyst/expressions/CastSuite.scala | 10 ++--- .../catalyst/util/DateTimeUtilsSuite.scala | 8 ++-- .../apache/spark/sql/json/JacksonParser.scala | 4 +- .../org/apache/spark/sql/json/JsonRDD.scala | 6 +-- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../execution/HiveCompatibilitySuite.scala | 6 +-- .../spark/sql/hive/HiveInspectors.scala | 4 +- 11 files changed, 50 insertions(+), 50 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 7e64cb0b54dba..fecfe6d71e9a7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -775,7 +775,7 @@ def to_posix_timstamp(dt): if dt: seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple())) - return int(seconds * 1e7 + dt.microsecond * 10) + return int(seconds * 1e6 + dt.microsecond) return to_posix_timstamp else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 662ceeca7782d..567feca7136f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -186,7 +186,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 10000) + buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 1000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -207,16 +207,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } private[this] def decimalToTimestamp(d: Decimal): Long = { - (d.toBigDecimal * 10000000L).longValue() + (d.toBigDecimal * 1000000L).longValue() } - // converting milliseconds to 100ns - private[this] def longToTimestamp(t: Long): Long = t * 10000L - // converting 100ns to seconds - private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 10000000L).toLong - // converting 100ns to seconds in double + // converting milliseconds to us + private[this] def longToTimestamp(t: Long): Long = t * 1000L + // converting us to seconds + private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 1000000L).toLong + // converting us to seconds in double private[this] def timestampToDouble(ts: Long): Double = { - ts / 10000000.0 + ts / 1000000.0 } // DateConverter @@ -229,7 +229,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case TimestampType => // throw valid 
precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. - buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 10000L)) + buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 1000L)) // Hive throws this exception as a Semantic Exception // It is never possible to compare result when hive return with exception, // so we can return null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index a492b966a5e31..dd5ec330a771b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -51,6 +51,6 @@ case class CurrentTimestamp() extends LeafExpression { override def dataType: DataType = TimestampType override def eval(input: InternalRow): Any = { - System.currentTimeMillis() * 10000L + System.currentTimeMillis() * 1000L } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 4269ad5d56737..c1ddee3ef0230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -34,8 +34,8 @@ object DateTimeUtils { // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5 final val SECONDS_PER_DAY = 60 * 60 * 24L - final val HUNDRED_NANOS_PER_SECOND = 1000L * 1000L * 10L - final val NANOS_PER_SECOND = HUNDRED_NANOS_PER_SECOND * 100 + final val MICROS_PER_SECOND = 1000L * 1000L + final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. @@ -77,8 +77,8 @@ object DateTimeUtils { threadLocalDateFormat.get.format(toJavaDate(days)) // Converts Timestamp to string according to Hive TimestampWritable convention. - def timestampToString(num100ns: Long): String = { - val ts = toJavaTimestamp(num100ns) + def timestampToString(us: Long): String = { + val ts = toJavaTimestamp(us) val timestampString = ts.toString val formatted = threadLocalTimestampFormat.get.format(ts) @@ -132,52 +132,52 @@ object DateTimeUtils { } /** - * Returns a java.sql.Timestamp from number of 100ns since epoch. + * Returns a java.sql.Timestamp from number of micros since epoch. */ - def toJavaTimestamp(num100ns: Long): Timestamp = { + def toJavaTimestamp(us: Long): Timestamp = { // setNanos() will overwrite the millisecond part, so the milliseconds should be // cut off at seconds - var seconds = num100ns / HUNDRED_NANOS_PER_SECOND - var nanos = num100ns % HUNDRED_NANOS_PER_SECOND + var seconds = us / MICROS_PER_SECOND + var micros = us % MICROS_PER_SECOND // setNanos() can not accept negative value - if (nanos < 0) { - nanos += HUNDRED_NANOS_PER_SECOND + if (micros < 0) { + micros += MICROS_PER_SECOND seconds -= 1 } val t = new Timestamp(seconds * 1000) - t.setNanos(nanos.toInt * 100) + t.setNanos(micros.toInt * 1000) t } /** - * Returns the number of 100ns since epoch from java.sql.Timestamp. + * Returns the number of micros since epoch from java.sql.Timestamp. 
*/ def fromJavaTimestamp(t: Timestamp): Long = { if (t != null) { - t.getTime() * 10000L + (t.getNanos().toLong / 100) % 10000L + t.getTime() * 1000L + (t.getNanos().toLong / 1000) % 1000L } else { 0L } } /** - * Returns the number of 100ns (hundred of nanoseconds) since epoch from Julian day + * Returns the number of microseconds since epoch from Julian day * and nanoseconds in a day */ def fromJulianDay(day: Int, nanoseconds: Long): Long = { // use Long to avoid rounding errors val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY - SECONDS_PER_DAY / 2 - seconds * HUNDRED_NANOS_PER_SECOND + nanoseconds / 100L + seconds * MICROS_PER_SECOND + nanoseconds / 1000L } /** - * Returns Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) + * Returns Julian day and nanoseconds in a day from the number of microseconds */ - def toJulianDay(num100ns: Long): (Int, Long) = { - val seconds = num100ns / HUNDRED_NANOS_PER_SECOND + SECONDS_PER_DAY / 2 + def toJulianDay(us: Long): (Int, Long) = { + val seconds = us / MICROS_PER_SECOND + SECONDS_PER_DAY / 2 val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH val secondsInDay = seconds % SECONDS_PER_DAY - val nanos = (num100ns % HUNDRED_NANOS_PER_SECOND) * 100L + val nanos = (us % MICROS_PER_SECOND) * 1000L (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 518961e38396f..919fdd470b79a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -293,15 +293,15 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from timestamp") { - val millis = 15 * 1000 + 2 - val seconds = millis * 1000 + 2 + val millis = 15 * 1000 + 3 + val seconds = millis * 1000 + 3 val ts = new Timestamp(millis) val tss = new Timestamp(seconds) checkEvaluation(cast(ts, ShortType), 15.toShort) checkEvaluation(cast(ts, IntegerType), 15) checkEvaluation(cast(ts, LongType), 15.toLong) - checkEvaluation(cast(ts, FloatType), 15.002f) - checkEvaluation(cast(ts, DoubleType), 15.002) + checkEvaluation(cast(ts, FloatType), 15.003f) + checkEvaluation(cast(ts, DoubleType), 15.003) checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) checkEvaluation(cast(cast(tss, IntegerType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) @@ -317,7 +317,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Decimal(1)) // A test for higher precision than millis - checkEvaluation(cast(cast(0.0000001, TimestampType), DoubleType), 0.0000001) + checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001) checkEvaluation(cast(Double.NaN, TimestampType), null) checkEvaluation(cast(1.0 / 0.0, TimestampType), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 1d4a60c81efc5..f63ac191e7366 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -24,11 +24,11 @@ import org.apache.spark.SparkFunSuite class DateTimeUtilsSuite extends SparkFunSuite { - 
test("timestamp and 100ns") { + test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) - now.setNanos(100) + now.setNanos(1000) val ns = DateTimeUtils.fromJavaTimestamp(now) - assert(ns % 10000000L === 1) + assert(ns % 1000000L === 1) assert(DateTimeUtils.toJavaTimestamp(ns) === now) List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => @@ -38,7 +38,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { } } - test("100ns and julian day") { + test("us and julian day") { val (d, ns) = DateTimeUtils.toJulianDay(0) assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH) assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 4b8ab63b5ab39..381e7ed54428f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -67,10 +67,10 @@ private[sql] object JacksonParser { DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) case (VALUE_STRING, TimestampType) => - DateTimeUtils.stringToTime(parser.getText).getTime * 10000L + DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => - parser.getLongValue * 10000L + parser.getLongValue * 1000L case (_, StringType) => val writer = new ByteArrayOutputStream() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 01ba05cbd14f1..b392a51bf7dce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -401,9 +401,9 @@ private[sql] object JsonRDD extends Logging { private def toTimestamp(value: Any): Long = { value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 10000L - case value: java.lang.Long => value * 10000L - case value: java.lang.String => DateTimeUtils.stringToTime(value).getTime * 10000L + case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 1000L + case value: java.lang.Long => value * 1000L + case value: java.lang.String => DateTimeUtils.stringToTime(value).getTime * 1000L } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 69ab1c292d221..566a52dc1b784 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -326,7 +326,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(cal.get(Calendar.HOUR) === 11) assert(cal.get(Calendar.MINUTE) === 22) assert(cal.get(Calendar.SECOND) === 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543500) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543000) } test("test DATE types") { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 415a81644c58f..c884c399281a8 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -254,9 
+254,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // the answer is sensitive for jdk version "udf_java_method", - // Spark SQL use Long for TimestampType, lose the precision under 100ns + // Spark SQL use Long for TimestampType, lose the precision under 1us "timestamp_1", - "timestamp_2" + "timestamp_2", + "timestamp_udf" ) /** @@ -803,7 +804,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "timestamp_comparison", "timestamp_lazy", "timestamp_null", - "timestamp_udf", "touch", "transform_ppr1", "transform_ppr2", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4cba17524af6c..a8f2ee37cb8ed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -267,7 +267,7 @@ private[hive] trait HiveInspectors { poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => val t = poi.getWritableConstantValue - t.getSeconds * 10000000L + t.getNanos / 100L + t.getSeconds * 1000000L + t.getNanos / 1000L case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantDoubleObjectInspector => @@ -332,7 +332,7 @@ private[hive] trait HiveInspectors { case x: DateObjectInspector => DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) case x: TimestampObjectInspector if x.preferWritable() => val t = x.getPrimitiveWritableObject(data) - t.getSeconds * 10000000L + t.getNanos / 100 + t.getSeconds * 1000000L + t.getNanos / 1000L case ti: TimestampObjectInspector => DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) case _ => pi.getPrimitiveJavaObject(data) From b55499a44ab74e33378211fb0d6940905d7c6318 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 8 Jul 2015 20:28:05 -0700 Subject: [PATCH 066/149] [SPARK-8932] Support copy() for UnsafeRows that do not use ObjectPools We call Row.copy() in many places throughout SQL but UnsafeRow currently throws UnsupportedOperationException when copy() is called. Supporting copying when ObjectPool is used may be difficult, since we may need to handle deep-copying of objects in the pool. In addition, this copy() method needs to produce a self-contained row object which may be passed around / buffered by downstream code which does not understand the UnsafeRow format. In the long run, we'll need to figure out how to handle the ObjectPool corner cases, but this may be unnecessary if other changes are made. Therefore, in order to unblock my sort patch (#6444) I propose that we support copy() for the cases where UnsafeRow does not use an ObjectPool and continue to throw UnsupportedOperationException when an ObjectPool is used. This patch accomplishes this by modifying UnsafeRow so that it knows the size of the row's backing data in order to be able to copy it into a byte array. Author: Josh Rosen Closes #7306 from JoshRosen/SPARK-8932 and squashes the following commits: 338e6bf [Josh Rosen] Support copy for UnsafeRows that do not use ObjectPools. 
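The copy() described above matters because a row object of this kind is typically re-pointed at new backing memory for each record, so any downstream code that buffers rows must materialise them first. A minimal standalone sketch of that hazard, using an invented `ReusedRow` class rather than Spark's UnsafeRow:

```scala
// A row that is re-pointed at a shared buffer for every record (toy model).
final class ReusedRow(private var data: Array[Byte]) {
  def pointTo(newData: Array[Byte]): Unit = { data = newData } // reused for the next record
  def getByte(i: Int): Byte = data(i)
  // copy() snapshots the current bytes into a fresh, self-contained row.
  def copy(): ReusedRow = new ReusedRow(data.clone())
}

object BufferingExample {
  def main(args: Array[String]): Unit = {
    val row = new ReusedRow(Array[Byte](1))
    val buffered = scala.collection.mutable.ArrayBuffer[ReusedRow]()
    buffered += row.copy()           // safe: independent of later re-pointing
    row.pointTo(Array[Byte](2))      // the reused row now refers to different data
    buffered += row                  // unsafe: observes the mutation above
    println(buffered.map(_.getByte(0)))  // ArrayBuffer(1, 2)
  }
}
```

Only the copied row keeps the value it held at buffering time; the directly buffered reference silently changes after the next `pointTo`.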
--- .../UnsafeFixedWidthAggregationMap.java | 12 +++-- .../sql/catalyst/expressions/UnsafeRow.java | 32 +++++++++++- .../expressions/UnsafeRowConverter.scala | 10 +++- .../expressions/UnsafeRowConverterSuite.scala | 52 ++++++++++++++----- 4 files changed, 87 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 1e79f4b2e88e5..79d55b36dab01 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -120,9 +120,11 @@ public UnsafeFixedWidthAggregationMap( this.bufferPool = new ObjectPool(initialCapacity); InternalRow initRow = initProjection.apply(emptyRow); - this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)]; + int emptyBufferSize = bufferConverter.getSizeRequirement(initRow); + this.emptyBuffer = new byte[emptyBufferSize]; int writtenLength = bufferConverter.writeRow( - initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize, + bufferPool); assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; // re-use the empty buffer only when there is no object saved in pool. reuseEmptyBuffer = bufferPool.size() == 0; @@ -142,6 +144,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { groupingKey, groupingKeyConversionScratchSpace, PlatformDependent.BYTE_ARRAY_OFFSET, + groupingKeySize, keyPool); assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; @@ -157,7 +160,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { // There is some objects referenced by emptyBuffer, so generate a new one InternalRow initRow = initProjection.apply(emptyRow); bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - bufferPool); + groupingKeySize, bufferPool); } loc.putNewKey( groupingKeyConversionScratchSpace, @@ -175,6 +178,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { address.getBaseObject(), address.getBaseOffset(), bufferConverter.numFields(), + loc.getValueLength(), bufferPool ); return currentBuffer; @@ -214,12 +218,14 @@ public MapEntry next() { keyAddress.getBaseObject(), keyAddress.getBaseOffset(), keyConverter.numFields(), + loc.getKeyLength(), keyPool ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), bufferConverter.numFields(), + loc.getValueLength(), bufferPool ); return entry; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index aeb64b045812f..edb7202245289 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -68,6 +68,9 @@ public final class UnsafeRow extends MutableRow { /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; + /** The size of this row's backing data, in bytes) */ + private int sizeInBytes; + public int length() { return numFields; } /** The width of the null tracking bit 
set, in bytes */ @@ -95,14 +98,17 @@ public UnsafeRow() { } * @param baseObject the base object * @param baseOffset the offset within the base object * @param numFields the number of fields in this row + * @param sizeInBytes the size of this row's backing data, in bytes * @param pool the object pool to hold arbitrary objects */ - public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) { + public void pointTo( + Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) { assert numFields >= 0 : "numFields should >= 0"; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; + this.sizeInBytes = sizeInBytes; this.pool = pool; } @@ -336,9 +342,31 @@ public double getDouble(int i) { } } + /** + * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal + * byte array rather than referencing data stored in a data page. + *
+ * This method is only supported on UnsafeRows that do not use ObjectPools. + */ @Override public InternalRow copy() { - throw new UnsupportedOperationException(); + if (pool != null) { + throw new UnsupportedOperationException( + "Copy is not supported for UnsafeRows that use object pools"); + } else { + UnsafeRow rowCopy = new UnsafeRow(); + final byte[] rowDataCopy = new byte[sizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + rowDataCopy, + PlatformDependent.BYTE_ARRAY_OFFSET, + sizeInBytes + ); + rowCopy.pointTo( + rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null); + return rowCopy; + } } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 1f395497a9839..6af5e6200e57b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { * @param row the row to convert * @param baseObject the base object of the destination address * @param baseOffset the base offset of the destination address + * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)` * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. */ - def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool) + def writeRow( + row: InternalRow, + baseObject: Object, + baseOffset: Long, + rowLengthInBytes: Int, + pool: ObjectPool): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool) if (writers.length > 0) { // zero-out the bitset diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 96d4e64ea344a..d00aeb4dfbf47 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -44,19 +44,32 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(row) assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + val numBytesWritten = + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) + // We can copy UnsafeRows as long as they don't reference ObjectPools + val unsafeRowCopy = unsafeRow.copy() + assert(unsafeRowCopy.getLong(0) === 0) + assert(unsafeRowCopy.getLong(1) === 1) + assert(unsafeRowCopy.getInt(2) === 2) + unsafeRow.setLong(1, 3) assert(unsafeRow.getLong(1) === 3) 
unsafeRow.setInt(2, 4) assert(unsafeRow.getInt(2) === 4) + + // Mutating the original row should not have changed the copy + assert(unsafeRowCopy.getLong(0) === 0) + assert(unsafeRowCopy.getLong(1) === 1) + assert(unsafeRowCopy.getInt(2) === 2) } test("basic conversion with primitive, string and binary types") { @@ -73,12 +86,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + val numBytesWritten = converter.writeRow( + row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() val pool = new ObjectPool(10) - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") assert(unsafeRow.get(2) === "World".getBytes) @@ -96,6 +111,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { unsafeRow.update(2, "Hello World".getBytes) assert(unsafeRow.get(2) === "Hello World".getBytes) assert(pool.size === 2) + + // We do not support copy() for UnsafeRows that reference ObjectPools + intercept[UnsupportedOperationException] { + unsafeRow.copy() + } } test("basic conversion with primitive, decimal and array") { @@ -111,12 +131,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(row) assert(sizeRequired === 8 + (8 * 3)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + val numBytesWritten = + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool) assert(numBytesWritten === sizeRequired) assert(pool.size === 2) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.get(1) === Decimal(1)) assert(unsafeRow.get(2) === Array(2)) @@ -142,11 +164,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(sizeRequired === 8 + (8 * 4) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + val numBytesWritten = + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow @@ -190,12 +214,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = 
converter.getSizeRequirement(rowWithAllNullColumns) val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, + sizeRequired, null) assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( - createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, + sizeRequired, null) for (i <- 0 to fieldTypes.length - 1) { assert(createdFromNull.isNullAt(i)) } @@ -233,10 +259,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val pool = new ObjectPool(1) val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, + sizeRequired, pool) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, + sizeRequired, pool) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) From 47ef423f860c3109d50c7e321616b267f4296e34 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 8 Jul 2015 20:29:08 -0700 Subject: [PATCH 067/149] [SPARK-8910] Fix MiMa flaky due to port contention issue Due to the way MiMa works, we currently start a `SQLContext` pretty early on. This causes us to start a `SparkUI` that attempts to bind to port 4040. Because many tests run in parallel on the Jenkins machines, this causes port contention sometimes and fails the MiMa tests. Note that we already disabled the SparkUI for scalatests. However, the MiMa test is run before we even have a chance to load the default scalatest settings, so we need to explicitly disable the UI ourselves. Author: Andrew Or Closes #7300 from andrewor14/mima-flaky and squashes the following commits: b55a547 [Andrew Or] Do not enable SparkUI during tests --- .../scala/org/apache/spark/sql/test/TestSQLContext.scala | 8 ++++---- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 9fa394525d65c..b3a4231da91c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** A SQLContext that can be used for local testing. 
*/ class LocalSQLContext extends SQLContext( - new SparkContext( - "local[2]", - "TestSQLContext", - new SparkConf().set("spark.sql.testkey", "true"))) { + new SparkContext("local[2]", "TestSQLContext", new SparkConf() + .set("spark.sql.testkey", "true") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) { override protected[sql] def createSession(): SQLSession = { new this.SQLSession() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7978fdacaedba..0f217bc66869f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -53,9 +53,10 @@ object TestHive "TestSQLContext", new SparkConf() .set("spark.sql.test", "") - .set( - "spark.sql.hive.metastore.barrierPrefixes", - "org.apache.spark.sql.hive.execution.PairSerDe"))) + .set("spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe") + // SPARK-8910 + .set("spark.ui.enabled", "false"))) /** * A locally running test instance of Spark's Hive execution engine. From aba5784dab24c03ddad89f7a1b5d3d0dc8d109be Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 9 Jul 2015 13:28:17 +0900 Subject: [PATCH 068/149] [SPARK-8937] [TEST] A setting `spark.unsafe.exceptionOnMemoryLeak ` is missing in ScalaTest config. `spark.unsafe.exceptionOnMemoryLeak` is present in the config of surefire. ``` org.apache.maven.plugins maven-surefire-plugin 2.18.1 ... true ... ``` but is absent in the config ScalaTest. Author: Kousuke Saruta Closes #7308 from sarutak/add-setting-for-memory-leak and squashes the following commits: 95644e7 [Kousuke Saruta] Added a setting for memory leak --- pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/pom.xml b/pom.xml index 9cf2471b51304..529e47f8b5253 100644 --- a/pom.xml +++ b/pom.xml @@ -1339,6 +1339,7 @@ false false true + true From 768907eb7b0d3c11a420ef281454e36167011c89 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 8 Jul 2015 22:05:58 -0700 Subject: [PATCH 069/149] [SPARK-8926][SQL] Good errors for ExpectsInputType expressions For example: `cannot resolve 'testfunction(null)' due to data type mismatch: argument 1 is expected to be of type int, however, null is of type datetype.` Author: Michael Armbrust Closes #7303 from marmbrus/expectsTypeErrors and squashes the following commits: c654a0e [Michael Armbrust] fix udts and make errors pretty 137160d [Michael Armbrust] style 5428fda [Michael Armbrust] style 10fac82 [Michael Armbrust] [SPARK-8926][SQL] Good errors for ExpectsInputType expressions --- .../catalyst/analysis/HiveTypeCoercion.scala | 12 +- .../expressions/ExpectsInputTypes.scala | 13 +- .../spark/sql/types/AbstractDataType.scala | 30 +++- .../apache/spark/sql/types/ArrayType.scala | 8 +- .../org/apache/spark/sql/types/DataType.scala | 4 +- .../apache/spark/sql/types/DecimalType.scala | 8 +- .../org/apache/spark/sql/types/MapType.scala | 8 +- .../apache/spark/sql/types/StructType.scala | 8 +- .../spark/sql/types/UserDefinedType.scala | 5 +- .../analysis/AnalysisErrorSuite.scala | 167 ++++++++++++++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 126 ++----------- .../analysis/HiveTypeCoercionSuite.scala | 8 + .../apache/spark/sql/hive/HiveContext.scala | 2 +- 13 files changed, 256 insertions(+), 143 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 5367b7f3308ee..8cb71995eb818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -702,11 +702,19 @@ object HiveTypeCoercion { @Nullable val ret: Expression = (inType, expectedType) match { // If the expected type is already a parent of the input type, no need to cast. - case _ if expectedType.isParentOf(inType) => e + case _ if expectedType.isSameType(inType) => e // Cast null type (usually from null literals) into target types case (NullType, target) => Cast(e, target.defaultConcreteType) + // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is + // already a number, leave it as is. + case (_: NumericType, NumericType) => e + + // If the function accepts any numeric type and the input is a string, we follow the hive + // convention and cast that input into a double + case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType) + // Implicit cast among numeric types // If input is a numeric type but not decimal, and we expect a decimal type, // cast the input to unlimited precision decimal. @@ -732,7 +740,7 @@ object HiveTypeCoercion { // First see if we can find our input type in the type collection. If we can, then just // use the current expression; otherwise, find the first one we can implicitly cast. case (_, TypeCollection(types)) => - if (types.exists(_.isParentOf(inType))) { + if (types.exists(_.isSameType(inType))) { e } else { types.flatMap(implicitCast(e, _)).headOption.orNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 916e30154d4f1..986cc09499d1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -37,7 +37,16 @@ trait ExpectsInputTypes { self: Expression => def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - // TODO: implement proper type checking. - TypeCheckResult.TypeCheckSuccess + val mismatches = children.zip(inputTypes).zipWithIndex.collect { + case ((child, expected), idx) if !expected.acceptsType(child.dataType) => + s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + + s"however, ${child.prettyString} is of type ${child.dataType.simpleString}." + } + + if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index fb1b47e946214..ad75fa2e31d90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -34,9 +34,16 @@ private[sql] abstract class AbstractDataType { private[sql] def defaultConcreteType: DataType /** - * Returns true if this data type is a parent of the `childCandidate`. 
+ * Returns true if this data type is the same type as `other`. This is different that equality + * as equality will also consider data type parametrization, such as decimal precision. */ - private[sql] def isParentOf(childCandidate: DataType): Boolean + private[sql] def isSameType(other: DataType): Boolean + + /** + * Returns true if `other` is an acceptable input type for a function that expectes this, + * possibly abstract, DataType. + */ + private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) /** Readable string representation for the type. */ private[sql] def simpleString: String @@ -58,11 +65,14 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) require(types.nonEmpty, s"TypeCollection ($types) cannot be empty") - private[sql] override def defaultConcreteType: DataType = types.head.defaultConcreteType + override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType + + override private[sql] def isSameType(other: DataType): Boolean = false - private[sql] override def isParentOf(childCandidate: DataType): Boolean = false + override private[sql] def acceptsType(other: DataType): Boolean = + types.exists(_.isSameType(other)) - private[sql] override def simpleString: String = { + override private[sql] def simpleString: String = { types.map(_.simpleString).mkString("(", " or ", ")") } } @@ -108,7 +118,7 @@ abstract class NumericType extends AtomicType { } -private[sql] object NumericType { +private[sql] object NumericType extends AbstractDataType { /** * Enables matching against NumericType for expressions: * {{{ @@ -117,6 +127,14 @@ private[sql] object NumericType { * }}} */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] + + override private[sql] def defaultConcreteType: DataType = DoubleType + + override private[sql] def simpleString: String = "numeric" + + override private[sql] def isSameType(other: DataType): Boolean = false + + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 43413ec761e6b..76ca7a84c1d1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -26,13 +26,13 @@ object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. 
*/ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) - private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) + override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[ArrayType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[ArrayType] } - private[sql] override def simpleString: String = "array" + override private[sql] def simpleString: String = "array" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index a4c2da8e05f5d..57718228e490f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -76,9 +76,9 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType - private[sql] override def defaultConcreteType: DataType = this + override private[sql] def defaultConcreteType: DataType = this - private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate + override private[sql] def isSameType(other: DataType): Boolean = this == other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 127b16ff85bed..a1cafeab1704d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -84,13 +84,13 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** Extra factory methods and pattern matchers for Decimals */ object DecimalType extends AbstractDataType { - private[sql] override def defaultConcreteType: DataType = Unlimited + override private[sql] def defaultConcreteType: DataType = Unlimited - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[DecimalType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[DecimalType] } - private[sql] override def simpleString: String = "decimal" + override private[sql] def simpleString: String = "decimal" val Unlimited: DecimalType = DecimalType(None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 868dea13d971e..ddead10bc2171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -69,13 +69,13 @@ case class MapType( object MapType extends AbstractDataType { - private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType) + override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[MapType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[MapType] } - private[sql] override def simpleString: String = "map" + override private[sql] def simpleString: String = "map" /** * Construct a [[MapType]] object with the given key type and value type. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e2d3f53f7d978..e0b8ff91786a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -303,13 +303,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru object StructType extends AbstractDataType { - private[sql] override def defaultConcreteType: DataType = new StructType + override private[sql] def defaultConcreteType: DataType = new StructType - private[sql] override def isParentOf(childCandidate: DataType): Boolean = { - childCandidate.isInstanceOf[StructType] + override private[sql] def isSameType(other: DataType): Boolean = { + other.isInstanceOf[StructType] } - private[sql] override def simpleString: String = "struct" + override private[sql] def simpleString: String = "struct" private[sql] def fromString(raw: String): StructType = DataType.fromString(raw) match { case t: StructType => t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 6b20505c6009a..e47cfb4833bd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -77,5 +77,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { * For UDT, asNullable will not change the nullability of its internal sqlType and just returns * itself. */ - private[spark] override def asNullable: UserDefinedType[UserType] = this + override private[spark] def asNullable: UserDefinedType[UserType] = this + + override private[sql] def acceptsType(dataType: DataType) = + this.getClass == dataType.getClass } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala new file mode 100644 index 0000000000000..73236c3acbca2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.{InternalRow, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +case class TestFunction( + children: Seq[Expression], + inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes { + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def dataType: DataType = StringType +} + +case class UnresolvedTestPlan() extends LeafNode { + override lazy val resolved = false + override def output: Seq[Attribute] = Nil +} + +class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { + import AnalysisSuite._ + + def errorTest( + name: String, + plan: LogicalPlan, + errorMessages: Seq[String], + caseSensitive: Boolean = true): Unit = { + test(name) { + val error = intercept[AnalysisException] { + if (caseSensitive) { + caseSensitiveAnalyze(plan) + } else { + caseInsensitiveAnalyze(plan) + } + } + + errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) + } + } + + val dateLit = Literal.create(null, DateType) + + errorTest( + "single invalid type, single arg", + testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), + "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" :: + "null is of type date" ::Nil) + + errorTest( + "single invalid type, second arg", + testRelation.select( + TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), + "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" :: + "null is of type date" ::Nil) + + errorTest( + "multiple invalid type", + testRelation.select( + TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), + "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: + "expected to be of type int" :: "null is of type date" ::Nil) + + errorTest( + "unresolved window function", + testRelation2.select( + WindowExpression( + UnresolvedWindowFunction( + "lead", + UnresolvedAttribute("c") :: Nil), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "lead" :: "window functions currently requires a HiveContext" :: Nil) + + errorTest( + "too many generators", + listRelation.select(Explode('list).as('a), Explode('list).as('b)), + "only one generator" :: "explode" :: Nil) + + errorTest( + "unresolved attributes", + testRelation.select('abcd), + "cannot resolve" :: "abcd" :: Nil) + + errorTest( + "bad casts", + testRelation.select(Literal(1).cast(BinaryType).as('badCast)), + "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + + errorTest( + "non-boolean filters", + testRelation.where(Literal(1)), + "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + + errorTest( + "missing group by", + testRelation2.groupBy('a)('b), + "'b'" :: "group by" :: Nil + ) + + errorTest( + "ambiguous field", + nestedRelation.select($"top.duplicateField"), + "Ambiguous reference to fields" :: "duplicateField" :: Nil, + 
caseSensitive = false) + + errorTest( + "ambiguous field due to case insensitivity", + nestedRelation.select($"top.differentCase"), + "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, + caseSensitive = false) + + errorTest( + "missing field", + nestedRelation2.select($"top.c"), + "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, + caseSensitive = false) + + errorTest( + "catch all unresolved plan", + UnresolvedTestPlan(), + "unresolved" :: Nil) + + + test("SPARK-6452 regression test") { + // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + val plan = + Aggregate( + Nil, + Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, + LocalRelation( + AttributeReference("a", IntegerType)(exprId = ExprId(2)))) + + assert(plan.resolved) + + val message = intercept[AnalysisException] { + caseSensitiveAnalyze(plan) + }.getMessage + + assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 77ca080f366cd..58df1de983a09 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { +object AnalysisSuite { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -61,25 +61,28 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: - StructField("duplicateField", StringType) :: - StructField("differentCase", StringType) :: - StructField("differentcase", StringType) :: Nil + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil ))()) val nestedRelation2 = LocalRelation( AttributeReference("top", StructType( StructField("aField", StringType) :: - StructField("bField", StringType) :: - StructField("cField", StringType) :: Nil + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil ))()) val listRelation = LocalRelation( AttributeReference("list", ArrayType(IntegerType))()) - before { - caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - } + caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) +} + + +class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { + import AnalysisSuite._ test("union project *") { val plan = (1 to 100) @@ -149,91 +152,6 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } - def errorTest( - name: String, - plan: LogicalPlan, - errorMessages: Seq[String], - caseSensitive: Boolean = true): Unit = { - test(name) { - val error = intercept[AnalysisException] { - if (caseSensitive) { - caseSensitiveAnalyze(plan) - } else { - 
caseInsensitiveAnalyze(plan) - } - } - - errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) - } - } - - errorTest( - "unresolved window function", - testRelation2.select( - WindowExpression( - UnresolvedWindowFunction( - "lead", - UnresolvedAttribute("c") :: Nil), - WindowSpecDefinition( - UnresolvedAttribute("a") :: Nil, - SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, - UnspecifiedFrame)).as('window)), - "lead" :: "window functions currently requires a HiveContext" :: Nil) - - errorTest( - "too many generators", - listRelation.select(Explode('list).as('a), Explode('list).as('b)), - "only one generator" :: "explode" :: Nil) - - errorTest( - "unresolved attributes", - testRelation.select('abcd), - "cannot resolve" :: "abcd" :: Nil) - - errorTest( - "bad casts", - testRelation.select(Literal(1).cast(BinaryType).as('badCast)), - "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) - - errorTest( - "non-boolean filters", - testRelation.where(Literal(1)), - "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) - - errorTest( - "missing group by", - testRelation2.groupBy('a)('b), - "'b'" :: "group by" :: Nil - ) - - errorTest( - "ambiguous field", - nestedRelation.select($"top.duplicateField"), - "Ambiguous reference to fields" :: "duplicateField" :: Nil, - caseSensitive = false) - - errorTest( - "ambiguous field due to case insensitivity", - nestedRelation.select($"top.differentCase"), - "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, - caseSensitive = false) - - errorTest( - "missing field", - nestedRelation2.select($"top.c"), - "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, - caseSensitive = false) - - case class UnresolvedTestPlan() extends LeafNode { - override lazy val resolved = false - override def output: Seq[Attribute] = Nil - } - - errorTest( - "catch all unresolved plan", - UnresolvedTestPlan(), - "unresolved" :: Nil) - test("divide should be casted into fractional types") { val testRelation2 = LocalRelation( @@ -258,22 +176,4 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } - - test("SPARK-6452 regression test") { - // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) - val plan = - Aggregate( - Nil, - Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, - LocalRelation( - AttributeReference("a", IntegerType)(exprId = ExprId(2)))) - - assert(plan.resolved) - - val message = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - }.getMessage - - assert(message.contains("resolved attribute(s) a#1 missing from a#2")) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 93db33d44eb25..6e3aa0eebeb15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -77,6 +77,14 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) shouldCast(IntegerType, 
TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) + + shouldCast(StringType, NumericType, DoubleType) + + // NumericType should not be changed when function accepts any of them. + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe => + shouldCast(tpe, NumericType, tpe) + } } test("ineligible implicit type cast") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 439d8cab5f257..bbc39b892b79e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -359,7 +359,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { hiveconf.set(key, value) } - private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { setConf(entry.key, entry.stringConverter(value)) } From a240bf3b44b15d0da5182d6ebec281dbdc5439e8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Jul 2015 22:08:50 -0700 Subject: [PATCH 070/149] Closes #7310. From 3dab0da42940a46f0c4aa4853bdb5c64c4cb2613 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Jul 2015 22:09:12 -0700 Subject: [PATCH 071/149] [SPARK-8928] [SQL] Makes CatalystSchemaConverter sticking to 1.4.x- when handling Parquet LISTs in compatible mode This PR is based on #7209 authored by Sephiroth-Lin. Author: Weizhong Lin Closes #7304 from liancheng/spark-8928 and squashes the following commits: 75267fe [Cheng Lian] Makes CatalystSchemaConverter sticking to 1.4.x- when handling LISTs in compatible mode --- .../spark/sql/parquet/CatalystSchemaConverter.scala | 6 ++++-- .../apache/spark/sql/parquet/ParquetSchemaSuite.scala | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index de3a72d8146c5..1ea6926af6d5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -461,7 +461,8 @@ private[parquet] class CatalystSchemaConverter( field.name, Types .buildGroup(REPEATED) - .addField(convertField(StructField("element", elementType, nullable))) + // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) + .addField(convertField(StructField("array_element", elementType, nullable))) .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level @@ -474,7 +475,8 @@ private[parquet] class CatalystSchemaConverter( ConversionPatterns.listType( repetition, field.name, - convertField(StructField("element", elementType, nullable), REPEATED)) + // "array" is the name chosen by parquet-avro (1.7.0 and prior version) + convertField(StructField("array", elementType, nullable), REPEATED)) // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 35d3c33f99a06..fa629392674bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -174,7 +174,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { """ |message root { | optional group _1 (LIST) { - | repeated int32 element; + | repeated int32 array; | } |} """.stripMargin) @@ -198,7 +198,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional group _1 (LIST) { | repeated group bag { - | optional int32 element; + | optional int32 array_element; | } | } |} @@ -267,7 +267,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group element { + | optional group array_element { | required int32 _1; | required double _2; | } @@ -616,7 +616,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional group f1 (LIST) { | repeated group bag { - | optional int32 element; + | optional int32 array_element; | } | } |} @@ -648,7 +648,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { nullable = true))), """message root { | optional group f1 (LIST) { - | repeated int32 element; + | repeated int32 array; | } |} """.stripMargin) From c056484c0741e2a03d4a916538e1b9e3e65e71c3 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Jul 2015 22:14:38 -0700 Subject: [PATCH 072/149] Revert "[SPARK-8928] [SQL] Makes CatalystSchemaConverter sticking to 1.4.x- when handling Parquet LISTs in compatible mode" This reverts commit 3dab0da42940a46f0c4aa4853bdb5c64c4cb2613. --- .../spark/sql/parquet/CatalystSchemaConverter.scala | 6 ++---- .../apache/spark/sql/parquet/ParquetSchemaSuite.scala | 10 +++++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 1ea6926af6d5b..de3a72d8146c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -461,8 +461,7 @@ private[parquet] class CatalystSchemaConverter( field.name, Types .buildGroup(REPEATED) - // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) - .addField(convertField(StructField("array_element", elementType, nullable))) + .addField(convertField(StructField("element", elementType, nullable))) .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level @@ -475,8 +474,7 @@ private[parquet] class CatalystSchemaConverter( ConversionPatterns.listType( repetition, field.name, - // "array" is the name chosen by parquet-avro (1.7.0 and prior version) - convertField(StructField("array", elementType, nullable), REPEATED)) + convertField(StructField("element", elementType, nullable), REPEATED)) // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index fa629392674bd..35d3c33f99a06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -174,7 +174,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { """ |message root { | optional group _1 (LIST) { - | repeated int32 array; + | repeated int32 element; | } |} """.stripMargin) @@ -198,7 +198,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional group _1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 element; | } | } |} @@ -267,7 +267,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group array_element { + | optional group element { | required int32 _1; | required double _2; | } @@ -616,7 +616,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional group f1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 element; | } | } |} @@ -648,7 +648,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { nullable = true))), """message root { | optional group f1 (LIST) { - | repeated int32 array; + | repeated int32 element; | } |} """.stripMargin) From 851e247caad0977cfd4998254d9602624e06539f Mon Sep 17 00:00:00 2001 From: Weizhong Lin Date: Wed, 8 Jul 2015 22:18:39 -0700 Subject: [PATCH 073/149] [SPARK-8928] [SQL] Makes CatalystSchemaConverter sticking to 1.4.x- when handling Parquet LISTs in compatible mode This PR is based on #7209 authored by Sephiroth-Lin. 
Author: Weizhong Lin Closes #7314 from liancheng/spark-8928 and squashes the following commits: 75267fe [Cheng Lian] Makes CatalystSchemaConverter sticking to 1.4.x- when handling LISTs in compatible mode --- .../spark/sql/parquet/CatalystSchemaConverter.scala | 6 ++++-- .../apache/spark/sql/parquet/ParquetSchemaSuite.scala | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index de3a72d8146c5..1ea6926af6d5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -461,7 +461,8 @@ private[parquet] class CatalystSchemaConverter( field.name, Types .buildGroup(REPEATED) - .addField(convertField(StructField("element", elementType, nullable))) + // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) + .addField(convertField(StructField("array_element", elementType, nullable))) .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level @@ -474,7 +475,8 @@ private[parquet] class CatalystSchemaConverter( ConversionPatterns.listType( repetition, field.name, - convertField(StructField("element", elementType, nullable), REPEATED)) + // "array" is the name chosen by parquet-avro (1.7.0 and prior version) + convertField(StructField("array", elementType, nullable), REPEATED)) // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 35d3c33f99a06..fa629392674bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -174,7 +174,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { """ |message root { | optional group _1 (LIST) { - | repeated int32 element; + | repeated int32 array; | } |} """.stripMargin) @@ -198,7 +198,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional group _1 (LIST) { | repeated group bag { - | optional int32 element; + | optional int32 array_element; | } | } |} @@ -267,7 +267,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group element { + | optional group array_element { | required int32 _1; | required double _2; | } @@ -616,7 +616,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional group f1 (LIST) { | repeated group bag { - | optional int32 element; + | optional int32 array_element; | } | } |} @@ -648,7 +648,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { nullable = true))), """message root { | optional group f1 (LIST) { - | repeated int32 element; + | repeated int32 array; | } |} """.stripMargin) From 09cb0d9c2dcb83818ced22ff9bd6a51688ea7ffe Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Jul 2015 00:26:25 -0700 Subject: [PATCH 074/149] [SPARK-8942][SQL] use double not decimal when cast double and float to timestamp Author: Wenchen Fan Closes #7312 from cloud-fan/minor and squashes the following commits: a4589fa [Wenchen Fan] use double not decimal when cast double and float to timestamp --- .../spark/sql/catalyst/expressions/Cast.scala | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 567feca7136f9..7f2383dedc035 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -192,23 +192,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Decimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp case DoubleType => - buildCast[Double](_, d => try { - decimalToTimestamp(Decimal(d)) - } catch { - case _: NumberFormatException => null - }) + buildCast[Double](_, d => doubleToTimestamp(d)) // TimestampWritable.floatToTimestamp case FloatType => - buildCast[Float](_, f => try { - decimalToTimestamp(Decimal(f)) - } catch { - case _: NumberFormatException => null - }) + buildCast[Float](_, f => doubleToTimestamp(f.toDouble)) } private[this] def decimalToTimestamp(d: Decimal): Long = { (d.toBigDecimal * 1000000L).longValue() } + private[this] def doubleToTimestamp(d: Double): Any = { + if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong + } // converting milliseconds to us private[this] def longToTimestamp(t: Long): Long = t * 1000L @@ -396,8 +391,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[InternalRow](_, row => { var i = 0 while (i < 
row.length) { - val v = row(i) - newRow.update(i, if (v == null) null else casts(i)(v)) + newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row(i))) i += 1 } newRow.copy() From f88b12537ee81d914ef7c51a08f80cb28d93c8ed Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 9 Jul 2015 08:16:26 -0700 Subject: [PATCH 075/149] [SPARK-6266] [MLLIB] PySpark SparseVector missing doc for size, indices, values Write missing pydocs in `SparseVector` attributes. Author: lewuathe Closes #7290 from Lewuathe/SPARK-6266 and squashes the following commits: 51d9895 [lewuathe] Update docs 0480d35 [lewuathe] Merge branch 'master' into SPARK-6266 ba42cf3 [lewuathe] [SPARK-6266] PySpark SparseVector missing doc for size, indices, values --- python/pyspark/mllib/linalg.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 51ac198305711..040886f71775b 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -445,8 +445,10 @@ def __init__(self, size, *args): values (sorted by index). :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, - or two sorted lists containing indices and values. + :param args: Active entries, as a dictionary {index: value, ...}, + a list of tuples [(index, value), ...], or a list of strictly i + ncreasing indices and a list of corresponding values [index, ...], + [value, ...]. Inactive entries are treated as zeros. >>> SparseVector(4, {1: 1.0, 3: 5.5}) SparseVector(4, {1: 1.0, 3: 5.5}) @@ -456,6 +458,7 @@ def __init__(self, size, *args): SparseVector(4, {1: 1.0, 3: 5.5}) """ self.size = int(size) + """ Size of the vector. """ assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" if len(args) == 1: pairs = args[0] @@ -463,7 +466,9 @@ def __init__(self, size, *args): pairs = pairs.items() pairs = sorted(pairs) self.indices = np.array([p[0] for p in pairs], dtype=np.int32) + """ A list of indices corresponding to active entries. """ self.values = np.array([p[1] for p in pairs], dtype=np.float64) + """ A list of values corresponding to active entries. """ else: if isinstance(args[0], bytes): assert isinstance(args[1], bytes), "values should be string too" From 23448a9e988a1b92bd05ee8c6c1a096c83375a12 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 09:20:16 -0700 Subject: [PATCH 076/149] [SPARK-8931] [SQL] Fallback to interpreted evaluation if failed to compile in codegen Exception will not be catched during tests. 
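
A rough sketch of the fallback pattern this patch introduces, transliterated to Python for brevity (the actual change is the Scala code in SparkPlan.scala shown below). The function names here are hypothetical stand-ins, not Spark APIs, and an environment variable stands in for the spark.testing JVM system property:

    import os

    def build_projection(codegen_compile, interpreted_build, testing=None):
        # codegen_compile / interpreted_build are zero-argument callables that
        # stand in for GenerateProjection.generate and InterpretedProjection,
        # keeping this sketch self-contained.
        if testing is None:
            testing = "SPARK_TESTING" in os.environ  # stand-in for sys.props("spark.testing")
        try:
            return codegen_compile()
        except Exception as e:
            if testing:
                raise  # fail loudly under tests so compilation bugs surface
            print("Failed to generate projection, falling back to interpreted:", e)
            return interpreted_build()

    def failing_codegen():
        raise RuntimeError("simulated compile error")

    # Prefers the compiled path, degrades gracefully outside of tests.
    print(build_projection(failing_codegen, lambda: "interpreted projection", testing=False))
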
cc marmbrus rxin Author: Davies Liu Closes #7309 from davies/fallback and squashes the following commits: 969a612 [Davies Liu] throw exception during tests f844f77 [Davies Liu] fallback a3091bc [Davies Liu] Merge branch 'master' of github.com:apache/spark into fallback 364a0d6 [Davies Liu] fallback to interpret mode if failed to compile --- .../spark/sql/execution/SparkPlan.scala | 51 +++++++++++++++++-- .../apache/spark/sql/sources/commands.scala | 13 ++++- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ca53186383237..4d7d8626a0ecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -153,12 +153,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ buf.toArray.map(converter(_).asInstanceOf[Row]) } + private[this] def isTesting: Boolean = sys.props.contains("spark.testing") + protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { - GenerateProjection.generate(expressions, inputSchema) + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate projection, fallback to interpret", e) + new InterpretedProjection(expressions, inputSchema) + } + } } else { new InterpretedProjection(expressions, inputSchema) } @@ -170,17 +182,36 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if(codegenEnabled) { - GenerateMutableProjection.generate(expressions, inputSchema) + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } } else { () => new InterpretedMutableProjection(expressions, inputSchema) } } - protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { if (codegenEnabled) { - GeneratePredicate.generate(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } + } } else { InterpretedPredicate.create(expression, inputSchema) } @@ -190,7 +221,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { if (codegenEnabled) { - GenerateOrdering.generate(order, inputSchema) + try { + GenerateOrdering.generate(order, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate ordering, fallback to interpreted", e) + new RowOrdering(order, inputSchema) + } + } } else { new RowOrdering(order, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index ecbc889770625..9189d176111d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -276,7 +276,18 @@ private[sql] case class InsertIntoHadoopFsRelation( log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { - GenerateProjection.generate(expressions, inputSchema) + + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (sys.props.contains("spark.testing")) { + throw e + } else { + log.error("failed to generate projection, fallback to interpreted", e) + new InterpretedProjection(expressions, inputSchema) + } + } } else { new InterpretedProjection(expressions, inputSchema) } From a1964e9d902bb31f001893da8bc81f6dce08c908 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Thu, 9 Jul 2015 09:22:24 -0700 Subject: [PATCH 077/149] [SPARK-8830] [SQL] native levenshtein distance Jira: https://issues.apache.org/jira/browse/SPARK-8830 rxin and HuJiayin can you have a look on it. Author: Tarek Auel Closes #7236 from tarekauel/native-levenshtein-distance and squashes the following commits: ee4c4de [Tarek Auel] [SPARK-8830] implemented improvement proposals c252e71 [Tarek Auel] [SPARK-8830] removed chartAt; use unsafe method for byte array comparison ddf2222 [Tarek Auel] Merge branch 'master' into native-levenshtein-distance 179920a [Tarek Auel] [SPARK-8830] added description 5e9ed54 [Tarek Auel] [SPARK-8830] removed StringUtils import dce4308 [Tarek Auel] [SPARK-8830] native levenshtein distance --- .../expressions/stringOperations.scala | 9 ++- .../expressions/StringFunctionsSuite.scala | 5 ++ .../apache/spark/unsafe/types/UTF8String.java | 66 ++++++++++++++++++- .../spark/unsafe/types/UTF8StringSuite.java | 24 +++++++ 4 files changed, 97 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 47fc7cdaa826c..57f436485becf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -284,13 +284,12 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = IntegerType - protected override def nullSafeEval(input1: Any, input2: Any): Any = - StringUtils.getLevenshteinDistance(input1.toString, input2.toString) + protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = + leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val stringUtils = classOf[StringUtils].getName - defineCodeGen(ctx, ev, (left, right) => - s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())") + nullSafeCodeGen(ctx, ev, (left, right) => + s"${ev.primitive} = $left.levenshteinDistance($right);") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 1efbe1a245e83..69bef1c63e9dc 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -282,5 +282,10 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0) checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3) checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + checkEvaluation(Levenshtein(Literal("千世"), Literal("fog")), 3) + checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4) + // scalastyle:on } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index d2a25096a5e7a..847d80ad583f6 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -99,8 +99,6 @@ public int numBytes() { /** * Returns the number of code points in it. - * - * This is only used by Substring() when `start` is negative. */ public int numChars() { int len = 0; @@ -254,6 +252,70 @@ public boolean equals(final Object other) { } } + /** + * Levenshtein distance is a metric for measuring the distance of two strings. The distance is + * defined by the minimum number of single-character edits (i.e. insertions, deletions or + * substitutions) that are required to change one of the strings into the other. + */ + public int levenshteinDistance(UTF8String other) { + // Implementation adopted from org.apache.common.lang3.StringUtils.getLevenshteinDistance + + int n = numChars(); + int m = other.numChars(); + + if (n == 0) { + return m; + } else if (m == 0) { + return n; + } + + UTF8String s, t; + + if (n <= m) { + s = this; + t = other; + } else { + s = other; + t = this; + int swap; + swap = n; + n = m; + m = swap; + } + + int p[] = new int[n + 1]; + int d[] = new int[n + 1]; + int swap[]; + + int i, i_bytes, j, j_bytes, num_bytes_j, cost; + + for (i = 0; i <= n; i++) { + p[i] = i; + } + + for (j = 0, j_bytes = 0; j < m; j_bytes += num_bytes_j, j++) { + num_bytes_j = numBytesForFirstByte(t.getByte(j_bytes)); + d[0] = j + 1; + + for (i = 0, i_bytes = 0; i < n; i_bytes += numBytesForFirstByte(s.getByte(i_bytes)), i++) { + if (s.getByte(i_bytes) != t.getByte(j_bytes) || + num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { + cost = 1; + } else { + cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, + s.offset + i_bytes, num_bytes_j)) ? 
0 : 1; + } + d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); + } + + swap = p; + p = d; + d = swap; + } + + return p[n]; + } + @Override public int hashCode() { int result = 1; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 8ec69ebac8b37..fb463ba17f50b 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -128,4 +128,28 @@ public void substring() { assertEquals(fromString("数据砖头").substring(3, 5), fromString("头")); assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷")); } + + @Test + public void levenshteinDistance() { + assertEquals( + UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")), 0); + assertEquals( + UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")), 1); + assertEquals( + UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")), 7); + assertEquals( + UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")), 1); + assertEquals( + UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")),3); + assertEquals( + UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")), 7); + assertEquals( + UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")), 7); + assertEquals( + UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")), 8); + assertEquals( + UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")),1); + assertEquals( + UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")),4); + } } From 59cc38944fe5c1dffc6551775bd939e2ac66c65e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Jul 2015 09:57:12 -0700 Subject: [PATCH 078/149] [SPARK-8940] [SPARKR] Don't overwrite given schema in createDataFrame JIRA: https://issues.apache.org/jira/browse/SPARK-8940 The given `schema` parameter will be overwritten in `createDataFrame` now. If it is not null, we shouldn't overwrite it. Author: Liang-Chi Hsieh Closes #7311 from viirya/df_not_overwrite_schema and squashes the following commits: 2385139 [Liang-Chi Hsieh] Don't overwrite given schema if it is not null. --- R/pkg/R/SQLContext.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9a743a3411533..30978bb50d339 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -86,7 +86,9 @@ infer_type <- function(x) { createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD - schema <- names(data) + if (is.null(schema)) { + schema <- names(data) + } n <- nrow(data) m <- ncol(data) # get rid of factor type From e204d22bb70f28b1cc090ab60f12078479be4ae0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 9 Jul 2015 10:01:01 -0700 Subject: [PATCH 079/149] [SPARK-8948][SQL] Remove ExtractValueWithOrdinal abstract class Also added more documentation for the file. Author: Reynold Xin Closes #7316 from rxin/extract-value and squashes the following commits: 069cb7e [Reynold Xin] Removed ExtractValueWithOrdinal. 621b705 [Reynold Xin] Reverted a line. 11ebd6c [Reynold Xin] [Minor][SQL] Improve documentation for complex type extractors. 
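
For orientation on the renamed extractors below: GetArrayItem looks an element up by ordinal, GetMapValue by key, and GetStructField by field name. A small illustrative PySpark sketch of how these extractors are reached from the public DataFrame API, not part of the patch itself; it assumes a live SQLContext bound to `sqlContext`, as in the PySpark shell of this era:

    from pyspark.sql import Row

    df = sqlContext.createDataFrame([
        Row(arr=[1, 2, 3], m={"a": 10}, s=Row(x=1, y=2))
    ])

    # arr[0]  -> GetArrayItem (lookup by ordinal)
    # m["a"]  -> GetMapValue  (lookup by key)
    # s.x     -> GetStructField (lookup by field name)
    df.select(df.arr.getItem(0), df.m.getItem("a"), df.s.getField("x")).show()
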
--- ...alue.scala => complexTypeExtractors.scala} | 54 ++++++++++++------- 1 file changed, 34 insertions(+), 20 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{ExtractValue.scala => complexTypeExtractors.scala} (86%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala similarity index 86% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 2b25ba03579ec..73cc930c45832 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines all the expressions to extract values out of complex types. +// For example, getting a field out of an array, map, or struct. +//////////////////////////////////////////////////////////////////////////////////////////////////// + object ExtractValue { /** @@ -73,11 +78,10 @@ object ExtractValue { } } - def unapply(g: ExtractValue): Option[(Expression, Expression)] = { - g match { - case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) - case s: ExtractValueWithStruct => Some((s.child, null)) - } + def unapply(g: ExtractValue): Option[(Expression, Expression)] = g match { + case o: GetArrayItem => Some((o.child, o.ordinal)) + case o: GetMapValue => Some((o.child, o.key)) + case s: ExtractValueWithStruct => Some((s.child, null)) } /** @@ -117,6 +121,8 @@ abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue /** * Returns the value of fields in the Struct `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetStructField(child: Expression, field: StructField, ordinal: Int) extends ExtractValueWithStruct { @@ -142,6 +148,8 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) /** * Returns the array of value of fields in the Array of Struct `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetArrayStructFields( child: Expression, @@ -178,25 +186,21 @@ case class GetArrayStructFields( } } -abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue { - self: Product => +/** + * Returns the field at `ordinal` in the Array `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. + */ +case class GetArrayItem(child: Expression, ordinal: Expression) + extends BinaryExpression with ExtractValue { - def ordinal: Expression - def child: Expression + override def toString: String = s"$child[$ordinal]" override def left: Expression = child override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. 
*/ override def nullable: Boolean = true - override def toString: String = s"$child[$ordinal]" -} - -/** - * Returns the field at `ordinal` in the Array `child` - */ -case class GetArrayItem(child: Expression, ordinal: Expression) - extends ExtractValueWithOrdinal { override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType @@ -227,10 +231,20 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `ordinal` in Map `child` + * Returns the value of key `ordinal` in Map `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ -case class GetMapValue(child: Expression, ordinal: Expression) - extends ExtractValueWithOrdinal { +case class GetMapValue(child: Expression, key: Expression) + extends BinaryExpression with ExtractValue { + + override def toString: String = s"$child[$key]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType From a870a82fb6f57bb63bd6f1e95da944a30f67519a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 9 Jul 2015 10:01:33 -0700 Subject: [PATCH 080/149] [SPARK-8926][SQL] Code review followup. I merged https://github.com/apache/spark/pull/7303 so it unblocks another PR. This addresses my own code review comment for that PR. Author: Reynold Xin Closes #7313 from rxin/adt and squashes the following commits: 7ade82b [Reynold Xin] Fixed unit tests. f8d5533 [Reynold Xin] [SPARK-8926][SQL] Code review followup. --- .../catalyst/expressions/ExpectsInputTypes.scala | 4 ++-- .../spark/sql/types/AbstractDataType.scala | 16 ++++++++++++++++ .../catalyst/analysis/AnalysisErrorSuite.scala | 8 ++++---- .../analysis/HiveTypeCoercionSuite.scala | 1 + 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 986cc09499d1f..3eb0eb195c80d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -39,8 +39,8 @@ trait ExpectsInputTypes { self: Expression => override def checkInputDataTypes(): TypeCheckResult = { val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => - s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + - s"however, ${child.prettyString} is of type ${child.dataType.simpleString}." + s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + + s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." } if (mismatches.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index ad75fa2e31d90..32f87440b4e37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -36,12 +36,28 @@ private[sql] abstract class AbstractDataType { /** * Returns true if this data type is the same type as `other`. 
This is different that equality * as equality will also consider data type parametrization, such as decimal precision. + * + * {{{ + * // this should return true + * DecimalType.isSameType(DecimalType(10, 2)) + * + * // this should return false + * NumericType.isSameType(DecimalType(10, 2)) + * }}} */ private[sql] def isSameType(other: DataType): Boolean /** * Returns true if `other` is an acceptable input type for a function that expectes this, * possibly abstract, DataType. + * + * {{{ + * // this should return true + * DecimalType.isSameType(DecimalType(10, 2)) + * + * // this should return true as well + * NumericType.acceptsType(DecimalType(10, 2)) + * }}} */ private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 73236c3acbca2..9d0c69a2451d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -58,7 +58,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { } } - errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) + errorMessages.foreach(m => assert(error.getMessage.toLowerCase.contains(m.toLowerCase))) } } @@ -68,21 +68,21 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" :: - "null is of type date" ::Nil) + "'null' is of type date" ::Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" :: - "null is of type date" ::Nil) + "'null' is of type date" ::Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "expected to be of type int" :: "null is of type date" ::Nil) + "expected to be of type int" :: "'null' is of type date" ::Nil) errorTest( "unresolved window function", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 6e3aa0eebeb15..acb9a433de903 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -79,6 +79,7 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) shouldCast(StringType, NumericType, DoubleType) + shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType) // NumericType should not be changed when function accepts any of them. 
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, From f6c0bd5c3755b2f9bab633a5d478240fdaf1c593 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Jul 2015 10:04:42 -0700 Subject: [PATCH 081/149] [SPARK-8938][SQL] Implement toString for Interval data type Author: Wenchen Fan Closes #7315 from cloud-fan/toString and squashes the following commits: 4fc8d80 [Wenchen Fan] Implement toString for Interval data type --- .../apache/spark/sql/catalyst/SqlParser.scala | 24 ++++++-- .../apache/spark/unsafe/types/Interval.java | 42 +++++++++++++ .../spark/unsafe/types/IntervalSuite.java | 59 +++++++++++++++++++ 3 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index dedd8c8fa3620..d4ef04c2294a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -353,22 +353,34 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { integral <~ intervalUnit("microsecond") ^^ { case num => num.toLong } protected lazy val millisecond: Parser[Long] = - integral <~ intervalUnit("millisecond") ^^ { case num => num.toLong * 1000 } + integral <~ intervalUnit("millisecond") ^^ { + case num => num.toLong * Interval.MICROS_PER_MILLI + } protected lazy val second: Parser[Long] = - integral <~ intervalUnit("second") ^^ { case num => num.toLong * 1000 * 1000 } + integral <~ intervalUnit("second") ^^ { + case num => num.toLong * Interval.MICROS_PER_SECOND + } protected lazy val minute: Parser[Long] = - integral <~ intervalUnit("minute") ^^ { case num => num.toLong * 1000 * 1000 * 60 } + integral <~ intervalUnit("minute") ^^ { + case num => num.toLong * Interval.MICROS_PER_MINUTE + } protected lazy val hour: Parser[Long] = - integral <~ intervalUnit("hour") ^^ { case num => num.toLong * 1000 * 1000 * 3600 } + integral <~ intervalUnit("hour") ^^ { + case num => num.toLong * Interval.MICROS_PER_HOUR + } protected lazy val day: Parser[Long] = - integral <~ intervalUnit("day") ^^ { case num => num.toLong * 1000 * 1000 * 3600 * 24 } + integral <~ intervalUnit("day") ^^ { + case num => num.toLong * Interval.MICROS_PER_DAY + } protected lazy val week: Parser[Long] = - integral <~ intervalUnit("week") ^^ { case num => num.toLong * 1000 * 1000 * 3600 * 24 * 7 } + integral <~ intervalUnit("week") ^^ { + case num => num.toLong * Interval.MICROS_PER_WEEK + } protected lazy val intervalLiteral: Parser[Literal] = INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index 3eb67ede062d9..0af982d4844c2 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -23,6 +23,13 @@ * The internal representation of interval type. 
*/ public final class Interval implements Serializable { + public static final long MICROS_PER_MILLI = 1000L; + public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000; + public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60; + public static final long MICROS_PER_HOUR = MICROS_PER_MINUTE * 60; + public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; + public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; + public final int months; public final long microseconds; @@ -44,4 +51,39 @@ public boolean equals(Object other) { public int hashCode() { return 31 * months + (int) microseconds; } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("interval"); + + if (months != 0) { + appendUnit(sb, months / 12, "year"); + appendUnit(sb, months % 12, "month"); + } + + if (microseconds != 0) { + long rest = microseconds; + appendUnit(sb, rest / MICROS_PER_WEEK, "week"); + rest %= MICROS_PER_WEEK; + appendUnit(sb, rest / MICROS_PER_DAY, "day"); + rest %= MICROS_PER_DAY; + appendUnit(sb, rest / MICROS_PER_HOUR, "hour"); + rest %= MICROS_PER_HOUR; + appendUnit(sb, rest / MICROS_PER_MINUTE, "minute"); + rest %= MICROS_PER_MINUTE; + appendUnit(sb, rest / MICROS_PER_SECOND, "second"); + rest %= MICROS_PER_SECOND; + appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond"); + rest %= MICROS_PER_MILLI; + appendUnit(sb, rest, "microsecond"); + } + + return sb.toString(); + } + + private void appendUnit(StringBuilder sb, long value, String unit) { + if (value != 0) { + sb.append(" " + value + " " + unit + "s"); + } + } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java new file mode 100644 index 0000000000000..0f4f38b2b03be --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -0,0 +1,59 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/
+
+package org.apache.spark.unsafe.types;
+
+import org.junit.Test;
+
+import static junit.framework.Assert.*;
+import static org.apache.spark.unsafe.types.Interval.*;
+
+public class IntervalSuite {
+
+  @Test
+  public void equalsTest() {
+    Interval i1 = new Interval(3, 123);
+    Interval i2 = new Interval(3, 321);
+    Interval i3 = new Interval(1, 123);
+    Interval i4 = new Interval(3, 123);
+
+    assertNotSame(i1, i2);
+    assertNotSame(i1, i3);
+    assertNotSame(i2, i3);
+    assertEquals(i1, i4);
+  }
+
+  @Test
+  public void toStringTest() {
+    Interval i;
+
+    i = new Interval(34, 0);
+    assertEquals(i.toString(), "interval 2 years 10 months");
+
+    i = new Interval(-34, 0);
+    assertEquals(i.toString(), "interval -2 years -10 months");
+
+    i = new Interval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
+    assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds");
+
+    i = new Interval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123);
+    assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds");
+
+    i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
+    assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds");
+  }
+}

From c59e268d17cf10e46dbdbe760e2a7580a6364692 Mon Sep 17 00:00:00 2001
From: JPark
Date: Thu, 9 Jul 2015 10:23:36 -0700
Subject: [PATCH 082/149] [SPARK-8863] [EC2] Check aws access key from aws credentials if there is no boto config

'spark_ec2.py' uses boto to control EC2, and boto supports '~/.aws/credentials', the
default credentials file of the AWS CLI. The boto reference describes the lookup order:

"A boto config file is a text file formatted like an .ini configuration file that
specifies values for options that control the behavior of the boto library. In
Unix/Linux systems, on startup, the boto library looks for configuration files in the
following locations and in the following order:

/etc/boto.cfg - for site-wide settings that all users on this machine will use
(if profile is given) ~/.aws/credentials - for credentials shared between SDKs
(if profile is given) ~/.boto - for user-specific settings
~/.aws/credentials - for credentials shared between SDKs
~/.boto - for user-specific settings"

* ref of boto: http://boto.readthedocs.org/en/latest/boto_config_tut.html
* ref of aws cli: http://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html

However, 'spark_ec2.py' only checks the boto config and the environment variables, so it
terminates even when '~/.aws/credentials' exists. This change makes it also check
'~/.aws/credentials'.
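For illustration only (not part of this patch; the values are placeholders), the shared
credentials file that this fallback reads typically has the form:

    [default]
    aws_access_key_id = <your-access-key-id>
    aws_secret_access_key = <your-secret-access-key>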
cc rxin Jira : https://issues.apache.org/jira/browse/SPARK-8863 Author: JPark Closes #7252 from JuhongPark/master and squashes the following commits: 23c5792 [JPark] Check aws access key from aws credentials if there is no boto config --- ec2/spark_ec2.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index dd0c12d25980b..ae4f2ecc5bde7 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -325,14 +325,16 @@ def parse_args(): home_dir = os.getenv('HOME') if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) + # If there is no boto config, check aws credentials + if not os.path.isfile(home_dir + '/.aws/credentials'): + if os.getenv('AWS_ACCESS_KEY_ID') is None: + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) + sys.exit(1) + if os.getenv('AWS_SECRET_ACCESS_KEY') is None: + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) + sys.exit(1) return (opts, action, cluster_name) From 0cd84c86cac68600a74d84e50ad40c0c8b84822a Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 9 Jul 2015 10:26:38 -0700 Subject: [PATCH 083/149] [SPARK-8703] [ML] Add CountVectorizer as a ml transformer to convert document to words count vector jira: https://issues.apache.org/jira/browse/SPARK-8703 Converts a text document to a sparse vector of token counts. I can further add an estimator to extract vocabulary from corpus if that's appropriate. Author: Yuhao Yang Closes #7084 from hhbyyh/countVectorization and squashes the following commits: 5f3f655 [Yuhao Yang] text change 24728e4 [Yuhao Yang] style improvement 576728a [Yuhao Yang] rename to model and some fix 1deca28 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into countVectorization 99b0c14 [Yuhao Yang] undo extension from HashingTF 12c2dc8 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into countVectorization 7ee1c31 [Yuhao Yang] extends HashingTF 809fb59 [Yuhao Yang] minor fix for ut 7c61fb3 [Yuhao Yang] add countVectorizer --- .../ml/feature/CountVectorizerModel.scala | 82 +++++++++++++++++++ .../ml/feature/CountVectorizorSuite.scala | 73 +++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala new file mode 100644 index 0000000000000..6b77de89a0330 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector} +import org.apache.spark.sql.types.{StringType, ArrayType, DataType} + +/** + * :: Experimental :: + * Converts a text document to a sparse vector of token counts. + * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. + */ +@Experimental +class CountVectorizerModel (override val uid: String, val vocabulary: Array[String]) + extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] { + + def this(vocabulary: Array[String]) = + this(Identifiable.randomUID("cntVec"), vocabulary) + + /** + * Corpus-specific filter to ignore scarce words in a document. For each document, terms with + * frequency (count) less than the given threshold are ignored. + * Default: 1 + * @group param + */ + val minTermFreq: IntParam = new IntParam(this, "minTermFreq", + "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " + + "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1)) + + /** @group setParam */ + def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) + + /** @group getParam */ + def getMinTermFreq: Int = $(minTermFreq) + + setDefault(minTermFreq -> 1) + + override protected def createTransformFunc: Seq[String] => Vector = { + val dict = vocabulary.zipWithIndex.toMap + document => + val termCounts = mutable.HashMap.empty[Int, Double] + document.foreach { term => + dict.get(term) match { + case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) + case None => // ignore terms not in the vocabulary + } + } + Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + } + + override protected def outputDataType: DataType = new VectorUDT() + + override def copy(extra: ParamMap): CountVectorizerModel = { + val copied = new CountVectorizerModel(uid, vocabulary) + copyValues(copied, extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala new file mode 100644 index 0000000000000..e90d9d4ef21ff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) + } + + test("CountVectorizerModel common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, "a b c d".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), + (1, "a b b c d a".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), + (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))), + (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string + (4, "a notInDict d".split(" ").toSeq, + Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary + )).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + val output = cv.transform(df).collect() + output.foreach { p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizerModel with minTermFreq") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))), + (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())), + (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTermFreq(3) + val output = cv.transform(df).collect() + output.foreach { p => + val features = p.getAs[Vector]("features") + val expected = p.getAs[Vector]("expected") + assert(features ~== expected absTol 1e-14) + } + } +} + + From 0b0b9ceaf73de472198c9804fb7ae61fa2a2e097 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 9 Jul 2015 11:11:34 -0700 Subject: [PATCH 084/149] [SPARK-8247] [SPARK-8249] [SPARK-8252] [SPARK-8254] [SPARK-8257] [SPARK-8258] [SPARK-8259] [SPARK-8261] [SPARK-8262] [SPARK-8253] [SPARK-8260] [SPARK-8267] [SQL] Add String Expressions Author: Cheng Hao Closes #6762 from chenghao-intel/str_funcs and squashes the following commits: b09a909 [Cheng Hao] update the code as feedback 7ebbf4c [Cheng Hao] Add more string expressions --- .../catalyst/analysis/FunctionRegistry.scala | 12 + .../expressions/stringOperations.scala | 306 ++++++++++++++- .../expressions/StringFunctionsSuite.scala | 138 +++++++ .../org/apache/spark/sql/functions.scala | 353 ++++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 132 ++++++- 
.../apache/spark/unsafe/types/UTF8String.java | 191 ++++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 94 ++++- 7 files changed, 1202 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5c25181e1cf50..f62d79f8cea6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -147,12 +147,24 @@ object FunctionRegistry { expression[Base64]("base64"), expression[Encode]("encode"), expression[Decode]("decode"), + expression[StringInstr]("instr"), expression[Lower]("lcase"), expression[Lower]("lower"), expression[StringLength]("length"), expression[Levenshtein]("levenshtein"), + expression[StringLocate]("locate"), + expression[StringLPad]("lpad"), + expression[StringTrimLeft]("ltrim"), + expression[StringFormat]("printf"), + expression[StringRPad]("rpad"), + expression[StringRepeat]("repeat"), + expression[StringReverse]("reverse"), + expression[StringTrimRight]("rtrim"), + expression[StringSpace]("space"), + expression[StringSplit]("split"), expression[Substring]("substr"), expression[Substring]("substring"), + expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), expression[Unhex]("unhex"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 57f436485becf..f64899c1ed84c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale import java.util.regex.Pattern import org.apache.commons.lang3.StringUtils @@ -104,7 +105,7 @@ case class RLike(left: Expression, right: Expression) override def toString: String = s"$left RLIKE $right" } -trait CaseConversionExpression extends ExpectsInputTypes { +trait String2StringExpression extends ExpectsInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -119,7 +120,7 @@ trait CaseConversionExpression extends ExpectsInputTypes { /** * A function that converts the characters of a string to uppercase. */ -case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { +case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -131,7 +132,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE /** * A function that converts the characters of a string to lowercase. */ -case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { +case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -187,6 +188,301 @@ case class EndsWith(left: Expression, right: Expression) } } +/** + * A function that trim the spaces from both ends for the specified string. 
+ */ +case class StringTrim(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trim() + + override def prettyName: String = "trim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trim()") + } +} + +/** + * A function that trim the spaces from left end for given string. + */ +case class StringTrimLeft(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trimLeft() + + override def prettyName: String = "ltrim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + } +} + +/** + * A function that trim the spaces from right end for given string. + */ +case class StringTrimRight(child: Expression) + extends UnaryExpression with String2StringExpression { + + def convert(v: UTF8String): UTF8String = v.trimRight() + + override def prettyName: String = "rtrim" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).trimRight()") + } +} + +/** + * A function that returns the position of the first occurrence of substr in the given string. + * Returns null if either of the arguments are null and + * returns 0 if substr could not be found in str. + * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. + */ +case class StringInstr(str: Expression, substr: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = substr + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, sub: Any): Any = { + string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1 + } + + override def prettyName: String = "instr" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (l, r) => + s"($l).indexOf($r, 0) + 1") + } +} + +/** + * A function that returns the position of the first occurrence of substr + * in given string after position pos. + */ +case class StringLocate(substr: Expression, str: Expression, start: Expression) + extends Expression with ExpectsInputTypes { + + def this(substr: Expression, str: Expression) = { + this(substr, str, Literal(0)) + } + + override def children: Seq[Expression] = substr :: str :: start :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = substr.nullable || str.nullable + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + + override def eval(input: InternalRow): Any = { + val s = start.eval(input) + if (s == null) { + // if the start position is null, we need to return 0, (conform to Hive) + 0 + } else { + val r = substr.eval(input) + if (r == null) { + null + } else { + val l = str.eval(input) + if (l == null) { + null + } else { + l.asInstanceOf[UTF8String].indexOf( + r.asInstanceOf[UTF8String], + s.asInstanceOf[Int]) + 1 + } + } + } + } + + override def prettyName: String = "locate" +} + +/** + * Returns str, left-padded with pad to a length of len. 
+ */ +case class StringLPad(str: Expression, len: Expression, pad: Expression) + extends Expression with ExpectsInputTypes { + + override def children: Seq[Expression] = str :: len :: pad :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children.exists(_.nullable) + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + + override def eval(input: InternalRow): Any = { + val s = str.eval(input) + if (s == null) { + null + } else { + val l = len.eval(input) + if (l == null) { + null + } else { + val p = pad.eval(input) + if (p == null) { + null + } else { + val len = l.asInstanceOf[Int] + val str = s.asInstanceOf[UTF8String] + val pad = p.asInstanceOf[UTF8String] + + str.lpad(len, pad) + } + } + } + } + + override def prettyName: String = "lpad" +} + +/** + * Returns str, right-padded with pad to a length of len. + */ +case class StringRPad(str: Expression, len: Expression, pad: Expression) + extends Expression with ExpectsInputTypes { + + override def children: Seq[Expression] = str :: len :: pad :: Nil + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children.exists(_.nullable) + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + + override def eval(input: InternalRow): Any = { + val s = str.eval(input) + if (s == null) { + null + } else { + val l = len.eval(input) + if (l == null) { + null + } else { + val p = pad.eval(input) + if (p == null) { + null + } else { + val len = l.asInstanceOf[Int] + val str = s.asInstanceOf[UTF8String] + val pad = p.asInstanceOf[UTF8String] + + str.rpad(len, pad) + } + } + } + } + + override def prettyName: String = "rpad" +} + +/** + * Returns the input formatted according do printf-style format strings + */ +case class StringFormat(children: Expression*) extends Expression { + + require(children.length >=1, "printf() should take at least 1 argument") + + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = children(0).nullable + override def dataType: DataType = StringType + private def format: Expression = children(0) + private def args: Seq[Expression] = children.tail + + override def eval(input: InternalRow): Any = { + val pattern = format.eval(input) + if (pattern == null) { + null + } else { + val sb = new StringBuffer() + val formatter = new java.util.Formatter(sb, Locale.US) + + val arglist = args.map(_.eval(input).asInstanceOf[AnyRef]) + formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*) + + UTF8String.fromString(sb.toString) + } + } + + override def prettyName: String = "printf" +} + +/** + * Returns the string which repeat the given string value n times. 
+ */ +case class StringRepeat(str: Expression, times: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = times + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType) + + override def nullSafeEval(string: Any, n: Any): Any = { + string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) + } + + override def prettyName: String = "repeat" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") + } +} + +/** + * Returns the reversed given string. + */ +case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { + override def convert(v: UTF8String): UTF8String = v.reverse() + + override def prettyName: String = "reverse" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).reverse()") + } +} + +/** + * Returns a n spaces string. + */ +case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(IntegerType) + + override def nullSafeEval(s: Any): Any = { + val length = s.asInstanceOf[Integer] + + val spaces = new Array[Byte](if (length < 0) 0 else length) + java.util.Arrays.fill(spaces, ' '.asInstanceOf[Byte]) + UTF8String.fromBytes(spaces) + } + + override def prettyName: String = "space" +} + +/** + * Splits str around pat (pattern is a regular expression). + */ +case class StringSplit(str: Expression, pattern: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = str + override def right: Expression = pattern + override def dataType: DataType = ArrayType(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, regex: Any): Any = { + val splits = + string.asInstanceOf[UTF8String].toString.split(regex.asInstanceOf[UTF8String].toString, -1) + splits.toSeq.map(UTF8String.fromString) + } + + override def prettyName: String = "split" +} + /** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. 
@@ -199,8 +495,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - - override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def nullable: Boolean = str.nullable || pos.nullable || len.nullable override def dataType: DataType = { if (!resolved) { @@ -373,4 +668,3 @@ case class Encode(value: Expression, charset: Expression) } } - diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 69bef1c63e9dc..b19f4ee37a109 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -288,4 +288,142 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4) // scalastyle:on } + + test("TRIM/LTRIM/RTRIM") { + val s = 'a.string.at(0) + checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) + checkEvaluation(StringTrim(s), "abdef", create_row(" abdef ")) + + checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef ")) + + checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef ")) + checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef ")) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTrimRight(s), " 花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) + checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + // scalastyle:on + } + + test("FORMAT") { + val f = 'f.string.at(0) + val d1 = 'd.int.at(1) + val s1 = 's.int.at(2) + + val row1 = create_row("aa%d%s", 12, "cc") + val row2 = create_row(null, 12, "cc") + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + + checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) + checkEvaluation(StringFormat(f, d1, s1), null, row2) + } + + test("INSTR") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val s3 = 'c.string.at(2) + val row1 = create_row("aaads", "aa", "zz") + + checkEvaluation(StringInstr(Literal("aaads"), Literal("aa")), 1, row1) + checkEvaluation(StringInstr(Literal("aaads"), Literal("de")), 0, row1) + checkEvaluation(StringInstr(Literal.create(null, StringType), Literal("de")), null, row1) + checkEvaluation(StringInstr(Literal("aaads"), Literal.create(null, StringType)), null, row1) + + checkEvaluation(StringInstr(s1, s2), 1, row1) + checkEvaluation(StringInstr(s1, s3), 0, row1) + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
+ checkEvaluation(StringInstr(s1, s2), 3, create_row("花花世界", "世界")) + checkEvaluation(StringInstr(s1, s2), 1, create_row("花花世界", "花")) + checkEvaluation(StringInstr(s1, s2), 0, create_row("花花世界", "小")) + // scalastyle:on + } + + test("LOCATE") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val s3 = 'c.string.at(2) + val s4 = 'd.int.at(3) + val row1 = create_row("aaads", "aa", "zz", 1) + + checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1) + checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1) + checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1) + + checkEvaluation(new StringLocate(s2, s1), 1, row1) + checkEvaluation(StringLocate(s2, s1, s4), 2, row1) + checkEvaluation(new StringLocate(s3, s1), 0, row1) + checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1) + } + + test("LPAD/RPAD") { + val s1 = 'a.string.at(0) + val s2 = 'b.int.at(1) + val s3 = 'c.string.at(2) + val row1 = create_row("hi", 5, "??") + val row2 = create_row("hi", 1, "?") + val row3 = create_row(null, 1, "?") + + checkEvaluation(StringLPad(Literal("hi"), Literal(5), Literal("??")), "???hi", row1) + checkEvaluation(StringLPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) + checkEvaluation(StringLPad(s1, s2, s3), "???hi", row1) + checkEvaluation(StringLPad(s1, s2, s3), "h", row2) + checkEvaluation(StringLPad(s1, s2, s3), null, row3) + + checkEvaluation(StringRPad(Literal("hi"), Literal(5), Literal("??")), "hi???", row1) + checkEvaluation(StringRPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) + checkEvaluation(StringRPad(s1, s2, s3), "hi???", row1) + checkEvaluation(StringRPad(s1, s2, s3), "h", row2) + checkEvaluation(StringRPad(s1, s2, s3), null, row3) + } + + test("REPEAT") { + val s1 = 'a.string.at(0) + val s2 = 'b.int.at(1) + val row1 = create_row("hi", 2) + val row2 = create_row(null, 1) + + checkEvaluation(StringRepeat(Literal("hi"), Literal(2)), "hihi", row1) + checkEvaluation(StringRepeat(Literal("hi"), Literal(-1)), "", row1) + checkEvaluation(StringRepeat(s1, s2), "hihi", row1) + checkEvaluation(StringRepeat(s1, s2), null, row2) + } + + test("REVERSE") { + val s = 'a.string.at(0) + val row1 = create_row("abccc") + checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) + checkEvaluation(StringReverse(s), "cccba", row1) + } + + test("SPACE") { + val s1 = 'b.int.at(0) + val row1 = create_row(2) + val row2 = create_row(null) + + checkEvaluation(StringSpace(Literal(2)), " ", row1) + checkEvaluation(StringSpace(Literal(-1)), "", row1) + checkEvaluation(StringSpace(Literal(0)), "", row1) + checkEvaluation(StringSpace(s1), " ", row1) + checkEvaluation(StringSpace(s1), null, row2) + } + + test("SPLIT") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val row1 = create_row("aa2bb3cc", "[1-9]+") + + checkEvaluation( + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) + checkEvaluation( + StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4da9ffc495e17..08bf37a5c223c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1626,6 +1626,179 @@ object 
functions { */ def ascii(columnName: String): Column = ascii(Column(columnName)) + /** + * Trim the spaces from both ends for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def trim(e: Column): Column = StringTrim(e.expr) + + /** + * Trim the spaces from both ends for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def trim(columnName: String): Column = trim(Column(columnName)) + + /** + * Trim the spaces from left end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(e: Column): Column = StringTrimLeft(e.expr) + + /** + * Trim the spaces from left end for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def ltrim(columnName: String): Column = ltrim(Column(columnName)) + + /** + * Trim the spaces from right end for the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(e: Column): Column = StringTrimRight(e.expr) + + /** + * Trim the spaces from right end for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def rtrim(columnName: String): Column = rtrim(Column(columnName)) + + /** + * Format strings in printf-style. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def formatString(format: Column, arguments: Column*): Column = { + StringFormat((format +: arguments).map(_.expr): _*) + } + + /** + * Format strings in printf-style. + * NOTE: `format` is the string value of the formatter, not column name. + * + * @group string_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def formatString(format: String, arguNames: String*): Column = { + StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*) + } + + /** + * Locate the position of the first occurrence of substr value in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(substr: String, sub: String): Column = instr(Column(substr), Column(sub)) + + /** + * Locate the position of the first occurrence of substr column in the given string. + * Returns null if either of the arguments are null. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def instr(substr: Column, sub: Column): Column = StringInstr(substr.expr, sub.expr) + + /** + * Locate the position of the first occurrence of substr. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String): Column = { + locate(Column(substr), Column(str)) + } + + /** + * Locate the position of the first occurrence of substr. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column): Column = { + new StringLocate(substr.expr, str.expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String, pos: String): Column = { + locate(Column(substr), Column(str), Column(pos)) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column, pos: Column): Column = { + StringLocate(substr.expr, str.expr, pos.expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: Column, str: Column, pos: Int): Column = { + StringLocate(substr.expr, str.expr, lit(pos).expr) + } + + /** + * Locate the position of the first occurrence of substr in a given string after position pos. + * + * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * could not be found in str. + * + * @group string_funcs + * @since 1.5.0 + */ + def locate(substr: String, str: String, pos: Int): Column = { + locate(Column(substr), Column(str), lit(pos)) + } + /** * Computes the specified value from binary to a base64 string. * @@ -1658,6 +1831,46 @@ object functions { */ def unbase64(columnName: String): Column = unbase64(Column(columnName)) + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: String, len: String, pad: String): Column = { + lpad(Column(str), Column(len), Column(pad)) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Column, pad: Column): Column = { + StringLPad(str.expr, len.expr, pad.expr) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: Column, len: Int, pad: Column): Column = { + StringLPad(str.expr, lit(len).expr, pad.expr) + } + + /** + * Left-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def lpad(str: String, len: Int, pad: String): Column = { + lpad(Column(str), len, Column(pad)) + } + /** * Computes the first argument into a binary from a string using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). @@ -1702,6 +1915,146 @@ object functions { def decode(columnName: String, charset: String): Column = decode(Column(columnName), charset) + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: String, len: String, pad: String): Column = { + rpad(Column(str), Column(len), Column(pad)) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Column, pad: Column): Column = { + StringRPad(str.expr, len.expr, pad.expr) + } + + /** + * Right-padded with pad to a length of len. + * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: String, len: Int, pad: String): Column = { + rpad(Column(str), len, Column(pad)) + } + + /** + * Right-padded with pad to a length of len. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + def rpad(str: Column, len: Int, pad: Column): Column = { + StringRPad(str.expr, lit(len).expr, pad.expr) + } + + /** + * Repeat the string value of the specified column n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(strColumn: String, timesColumn: String): Column = { + repeat(Column(strColumn), Column(timesColumn)) + } + + /** + * Repeat the string expression value n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, times: Column): Column = { + StringRepeat(str.expr, times.expr) + } + + /** + * Repeat the string value of the specified column n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(strColumn: String, times: Int): Column = { + repeat(Column(strColumn), times) + } + + /** + * Repeat the string expression value n times. + * + * @group string_funcs + * @since 1.5.0 + */ + def repeat(str: Column, times: Int): Column = { + StringRepeat(str.expr, lit(times).expr) + } + + /** + * Splits str around pattern (pattern is a regular expression). + * + * @group string_funcs + * @since 1.5.0 + */ + def split(strColumnName: String, pattern: String): Column = { + split(Column(strColumnName), pattern) + } + + /** + * Splits str around pattern (pattern is a regular expression). + * NOTE: pattern is a string represent the regular expression. + * + * @group string_funcs + * @since 1.5.0 + */ + def split(str: Column, pattern: String): Column = { + StringSplit(str.expr, lit(pattern).expr) + } + + /** + * Reversed the string for the specified column. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: String): Column = { + reverse(Column(str)) + } + + /** + * Reversed the string for the specified value. + * + * @group string_funcs + * @since 1.5.0 + */ + def reverse(str: Column): Column = { + StringReverse(str.expr) + } + + /** + * Make a n spaces of string. + * + * @group string_funcs + * @since 1.5.0 + */ + def space(n: String): Column = { + space(Column(n)) + } + + /** + * Make a n spaces of string. 
+ * + * @group string_funcs + * @since 1.5.0 + */ + def space(n: Column): Column = { + StringSpace(n.expr) + } ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index afba28515e032..173280375c411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -209,21 +209,14 @@ class DataFrameFunctionsSuite extends QueryTest { } test("string length function") { + val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( - nullStrings.select(strlen($"s"), strlen("s")), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l, l) - }) + df.select(strlen($"a"), strlen("b")), + Row(3, 0)) checkAnswer( - nullStrings.selectExpr("length(s)"), - nullStrings.collect().toSeq.map { r => - val v = r.getString(1) - val l = if (v == null) null else v.length - Row(l) - }) + df.selectExpr("length(a)", "length(b)"), + Row(3, 0)) } test("Levenshtein distance") { @@ -273,4 +266,119 @@ class DataFrameFunctionsSuite extends QueryTest { Row(bytes, "大千世界")) // scalastyle:on } + + test("string trim functions") { + val df = Seq((" example ", "")).toDF("a", "b") + + checkAnswer( + df.select(ltrim($"a"), rtrim($"a"), trim($"a")), + Row("example ", " example", "example")) + + checkAnswer( + df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), + Row("example ", " example", "example")) + } + + test("string formatString function") { + val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + Row("aa123cc", "aa123cc")) + + checkAnswer( + df.selectExpr("printf(a, b, c)"), + Row("aa123cc")) + } + + test("string instr function") { + val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") + + checkAnswer( + df.select(instr($"a", $"b"), instr("a", "b")), + Row(1, 1)) + + checkAnswer( + df.selectExpr("instr(a, b)"), + Row(1)) + } + + test("string locate function") { + val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + + checkAnswer( + df.select( + locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1), + locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")), + Row(1, 1, 2, 2, 2, 2)) + + checkAnswer( + df.selectExpr("locate(b, a)", "locate(b, a, d)"), + Row(1, 2)) + } + + test("string padding functions") { + val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") + + checkAnswer( + df.select( + lpad($"a", $"b", $"c"), rpad("a", "b", "c"), + lpad($"a", 1, $"c"), rpad("a", 1, "c")), + Row("???hi", "hi???", "h", "h")) + + checkAnswer( + df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), + Row("???hi", "hi???", "h", "h")) + } + + test("string repeat function") { + val df = Seq(("hi", 2)).toDF("a", "b") + + checkAnswer( + df.select( + repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), + Row("hihi", "hihi", "hihi", "hihi")) + + checkAnswer( + df.selectExpr("repeat(a, 2)", "repeat(a, b)"), + Row("hihi", "hihi")) + } + + test("string reverse function") { + val df = Seq(("hi", "hhhi")).toDF("a", "b") + + checkAnswer( + df.select(reverse($"a"), reverse("b")), + Row("ih", "ihhh")) + + checkAnswer( + 
df.selectExpr("reverse(b)"), + Row("ihhh")) + } + + test("string space function") { + val df = Seq((2, 3)).toDF("a", "b") + + checkAnswer( + df.select(space($"a"), space("b")), + Row(" ", " ")) + + checkAnswer( + df.selectExpr("space(b)"), + Row(" ")) + } + + test("string split function") { + val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select( + split($"a", "[1-9]+"), + split("a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+')"), + Row(Seq("aa", "bb", "cc"))) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 847d80ad583f6..60d050b0a0c97 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -25,6 +25,7 @@ import static org.apache.spark.unsafe.PlatformDependent.*; + /** * A UTF-8 String for internal Spark use. *
@@ -204,6 +205,196 @@ public UTF8String toLowerCase() { return fromString(toString().toLowerCase()); } + /** + * Copy the bytes from the current UTF8String, and make a new UTF8String. + * @param start the start position of the current UTF8String in bytes. + * @param end the end position of the current UTF8String in bytes. + * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. + */ + private UTF8String copyUTF8String(int start, int end) { + int len = end - start + 1; + byte[] newBytes = new byte[len]; + copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + return UTF8String.fromBytes(newBytes); + } + + public UTF8String trim() { + int s = 0; + int e = this.numBytes - 1; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (s > e) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, e); + } + } + + public UTF8String trimLeft() { + int s = 0; + // skip all of the space (0x20) in the left side + while (s < this.numBytes && getByte(s) == 0x20) s++; + if (s == this.numBytes) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(s, this.numBytes - 1); + } + } + + public UTF8String trimRight() { + int e = numBytes - 1; + // skip all of the space (0x20) in the right side + while (e >= 0 && getByte(e) == 0x20) e--; + + if (e < 0) { + // empty string + return UTF8String.fromBytes(new byte[0]); + } else { + return copyUTF8String(0, e); + } + } + + public UTF8String reverse() { + byte[] bytes = getBytes(); + byte[] result = new byte[bytes.length]; + + int i = 0; // position in byte + while (i < numBytes) { + int len = numBytesForFirstByte(getByte(i)); + System.arraycopy(bytes, i, result, result.length - i - len, len); + + i += len; + } + + return UTF8String.fromBytes(result); + } + + public UTF8String repeat(int times) { + if (times <=0) { + return fromBytes(new byte[0]); + } + + byte[] newBytes = new byte[numBytes * times]; + System.arraycopy(getBytes(), 0, newBytes, 0, numBytes); + + int copied = 1; + while (copied < times) { + int toCopy = Math.min(copied, times - copied); + System.arraycopy(newBytes, 0, newBytes, copied * numBytes, numBytes * toCopy); + copied += toCopy; + } + + return UTF8String.fromBytes(newBytes); + } + + /** + * Returns the position of the first occurrence of substr in + * current string from the specified position (0-based index). + * + * @param v the string to be searched + * @param start the start position of the current string for searching + * @return the position of the first occurrence of substr, if not found, -1 returned. + */ + public int indexOf(UTF8String v, int start) { + if (v.numBytes() == 0) { + return 0; + } + + // locate to the start position. + int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + do { + if (i + v.numBytes > numBytes) { + return -1; + } + if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + return c; + } + i += numBytesForFirstByte(getByte(i)); + c += 1; + } while(i < numBytes); + + return -1; + } + + /** + * Returns str, right-padded with pad to a length of len + * For example: + * ('hi', 5, '??') => 'hi???' 
+ * ('hi', 1, '??') => 'h' + */ + public UTF8String rpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + System.arraycopy(getBytes(), 0, data, 0, this.numBytes); + int offset = this.numBytes; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + + return UTF8String.fromBytes(data); + } + } + + /** + * Returns str, left-padded with pad to a length of len. + * For example: + * ('hi', 5, '??') => '???hi' + * ('hi', 1, '??') => 'h' + */ + public UTF8String lpad(int len, UTF8String pad) { + int spaces = len - this.numChars(); // number of char need to pad + if (spaces <= 0) { + // no padding at all, return the substring of the current string + return substring(0, len); + } else { + int padChars = pad.numChars(); + int count = spaces / padChars; // how many padding string needed + // the partial string of the padding + UTF8String remain = pad.substring(0, spaces - padChars * count); + + byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; + + int offset = 0; + int idx = 0; + byte[] padBytes = pad.getBytes(); + while (idx < count) { + System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + ++idx; + offset += pad.numBytes; + } + System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + offset += remain.numBytes; + System.arraycopy(getBytes(), 0, data, offset, numBytes()); + + return UTF8String.fromBytes(data); + } + } + @Override public String toString() { try { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index fb463ba17f50b..694bdc29f39d1 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -121,12 +121,94 @@ public void endsWith() { @Test public void substring() { - assertEquals(fromString("hello").substring(0, 0), fromString("")); - assertEquals(fromString("hello").substring(1, 3), fromString("el")); - assertEquals(fromString("数据砖头").substring(0, 1), fromString("数")); - assertEquals(fromString("数据砖头").substring(1, 3), fromString("据砖")); - assertEquals(fromString("数据砖头").substring(3, 5), fromString("头")); - assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷")); + assertEquals(fromString(""), fromString("hello").substring(0, 0)); + assertEquals(fromString("el"), fromString("hello").substring(1, 3)); + assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1)); + assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); + assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5)); + assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2)); + } + + @Test + public void trims() { + assertEquals(fromString("hello"), fromString(" hello ").trim()); + assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); + 
assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + + assertEquals(fromString(""), fromString(" ").trim()); + assertEquals(fromString(""), fromString(" ").trimLeft()); + assertEquals(fromString(""), fromString(" ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight()); + + assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); + assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + } + + @Test + public void indexOf() { + assertEquals(0, fromString("").indexOf(fromString(""), 0)); + assertEquals(-1, fromString("").indexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").indexOf(fromString(""), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("l"), 0)); + assertEquals(3, fromString("hello").indexOf(fromString("l"), 3)); + assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0)); + assertEquals(2, fromString("hello").indexOf(fromString("ll"), 0)); + assertEquals(-1, fromString("hello").indexOf(fromString("ll"), 4)); + assertEquals(1, fromString("数据砖头").indexOf(fromString("据砖"), 0)); + assertEquals(-1, fromString("数据砖头").indexOf(fromString("数"), 3)); + assertEquals(0, fromString("数据砖头").indexOf(fromString("数"), 0)); + assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0)); + } + + @Test + public void reverse() { + assertEquals(fromString("olleh"), fromString("hello").reverse()); + assertEquals(fromString(""), fromString("").reverse()); + assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); + assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); + } + + @Test + public void repeat() { + assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5)); + assertEquals(fromString("数d"), fromString("数d").repeat(1)); + assertEquals(fromString(""), fromString("数d").repeat(-1)); + } + + @Test + public void pad() { + assertEquals(fromString("hel"), fromString("hello").lpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").lpad(5, fromString("????"))); + assertEquals(fromString("?hello"), fromString("hello").lpad(6, fromString("????"))); + assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????"))); + assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").lpad(7, fromString("?????"))); + + assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????"))); + assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????"))); + assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????"))); + assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????"))); + assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????"))); + assertEquals(fromString("???????"), fromString("").rpad(7, fromString("?????"))); + + + assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????"))); + assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????"))); + assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????"))); + assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者"))); + assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, 
fromString("孙行者"))); + assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者"))); + + assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????"))); + assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????"))); + assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????"))); + assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者"))); + assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者"))); } @Test From 7ce3b818fb1ba3f291eda58988e4808e999cae3a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 9 Jul 2015 13:19:36 -0700 Subject: [PATCH 085/149] [MINOR] [STREAMING] Fix log statements in ReceiverSupervisorImpl Log statements incorrectly showed that the executor was being stopped when receiver was being stopped. Author: Tathagata Das Closes #7328 from tdas/fix-log and squashes the following commits: 9cc6e99 [Tathagata Das] Fix log statements. --- .../spark/streaming/receiver/ReceiverSupervisor.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 33be067ebdaf2..eeb14ca3a49e9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -182,12 +182,12 @@ private[streaming] abstract class ReceiverSupervisor( /** Wait the thread until the supervisor is stopped */ def awaitTermination() { + logInfo("Waiting for receiver to be stopped") stopLatch.await() - logInfo("Waiting for executor stop is over") if (stoppingError != null) { - logError("Stopped executor with error: " + stoppingError) + logError("Stopped receiver with error: " + stoppingError) } else { - logWarning("Stopped executor without error") + logInfo("Stopped receiver without error") } if (stoppingError != null) { throw stoppingError From 930fe95350f8865e2af2d7afa5b717210933cd43 Mon Sep 17 00:00:00 2001 From: xutingjun Date: Thu, 9 Jul 2015 13:21:10 -0700 Subject: [PATCH 086/149] [SPARK-8953] SPARK_EXECUTOR_CORES is not read in SparkSubmit The configuration ```SPARK_EXECUTOR_CORES``` won't put into ```SparkConf```, so it has no effect to the dynamic executor allocation. 
Author: xutingjun Closes #7322 from XuTingjun/SPARK_EXECUTOR_CORES and squashes the following commits: 2cafa89 [xutingjun] make SPARK_EXECUTOR_CORES has effect to dynamicAllocation --- .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 73ab18332feb4..6e3c0b21b33c2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -162,6 +162,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull executorCores = Option(executorCores) .orElse(sparkProperties.get("spark.executor.cores")) + .orElse(env.get("SPARK_EXECUTOR_CORES")) .orNull totalExecutorCores = Option(totalExecutorCores) .orElse(sparkProperties.get("spark.cores.max")) From 88bf430331eef3c02438ca441616034486e15789 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 9 Jul 2015 13:22:17 -0700 Subject: [PATCH 087/149] [SPARK-7419] [STREAMING] [TESTS] Fix CheckpointSuite.recovery with file input stream Fix this failure: https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/2886/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.3,label=centos/testReport/junit/org.apache.spark.streaming/CheckpointSuite/recovery_with_file_input_stream/ To reproduce this failure, you can add `Thread.sleep(2000)` before this line https://github.com/apache/spark/blob/a9c4e29950a14e32acaac547e9a0e8879fd37fc9/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala#L477 Author: zsxwing Closes #7323 from zsxwing/SPARK-7419 and squashes the following commits: b3caf58 [zsxwing] Fix CheckpointSuite.recovery with file input stream --- .../spark/streaming/CheckpointSuite.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6b0a3f91d4d06..6a94928076236 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -424,11 +424,11 @@ class CheckpointSuite extends TestSuiteBase { } } } - clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { // Wait until all files have been recorded and all batches have started assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } + clock.advance(batchDuration.milliseconds) // Wait for a checkpoint to be written eventually(eventuallyTimeout) { assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) @@ -454,9 +454,12 @@ class CheckpointSuite extends TestSuiteBase { // recorded before failure were saved and successfully recovered logInfo("*********** RESTARTING ************") withStreamingContext(new StreamingContext(checkpointDir)) { ssc => - // So that the restarted StreamingContext's clock has gone forward in time since failure - ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString) - val oldClockTime = clock.getTimeMillis() + // "batchDuration.milliseconds * 3" has gone before restarting StreamingContext. 
And because + // the recovery time is read from the checkpoint time but the original clock doesn't align + // with the batch time, we need to add the offset "batchDuration.milliseconds / 2". + ssc.conf.set("spark.streaming.manualClock.jump", + (batchDuration.milliseconds / 2 + batchDuration.milliseconds * 3).toString) + val oldClockTime = clock.getTimeMillis() // 15000ms clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val batchCounter = new BatchCounter(ssc) val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] @@ -467,10 +470,10 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Verify that the clock has traveled forward to the expected time eventually(eventuallyTimeout) { - clock.getTimeMillis() === oldClockTime + assert(clock.getTimeMillis() === oldClockTime) } - // Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch) - val numBatchesAfterRestart = 4 + // There are 5 batches between 6000ms and 15000ms (inclusive). + val numBatchesAfterRestart = 5 eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === numBatchesAfterRestart) } @@ -483,7 +486,6 @@ class CheckpointSuite extends TestSuiteBase { assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - clock.advance(batchDuration.milliseconds) logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() From ebdf58538058e57381c04b6725d4be0c37847ed3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 9 Jul 2015 13:25:11 -0700 Subject: [PATCH 088/149] [SPARK-2017] [UI] Stage page hangs with many tasks (This reopens a patch that was closed in the past: #6248) When you view the stage page while running the following: ``` sc.parallelize(1 to X, 10000).count() ``` The page never loads, the job is stalled, and you end up running into an OOM: ``` HTTP ERROR 500 Problem accessing /stages/stage/. Reason: Server Error Caused by: java.lang.OutOfMemoryError: Java heap space at java.util.Arrays.copyOf(Arrays.java:2367) at java.lang.AbstractStringBuilder.expandCapacity(AbstractStringBuilder.java:130) ``` This patch compresses Jetty responses in gzip. The correct long-term fix is to add pagination. 
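As a rough sketch of the approach (illustrative only; the helper below is not part of the patch, and assumes Jetty's `GzipHandler` from `org.eclipse.jetty.server.handler`), each UI handler gets wrapped in a gzip handler before being registered:

```scala
import org.eclipse.jetty.server.handler.{ContextHandlerCollection, GzipHandler}
import org.eclipse.jetty.servlet.ServletContextHandler

// Wrap every handler so its responses are gzip-compressed before being sent to the browser.
def wrapHandlersWithGzip(handlers: Seq[ServletContextHandler]): ContextHandlerCollection = {
  val collection = new ContextHandlerCollection
  val gzipHandlers = handlers.map { h =>
    val gzipHandler = new GzipHandler
    gzipHandler.setHandler(h) // delegate to the original handler, compressing its output
    gzipHandler
  }
  collection.setHandlers(gzipHandlers.toArray)
  collection
}
```

Compression trades a little driver CPU for a much smaller response body, which keeps the page loadable until pagination is implemented.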
Author: Andrew Or Closes #7296 from andrewor14/gzip-jetty and squashes the following commits: a051c64 [Andrew Or] Use GZIP to compress Jetty responses --- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 06e616220c706..f413c1d37fbb6 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -210,10 +210,16 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection - collection.setHandlers(handlers.toArray) addFilters(handlers, conf) + val collection = new ContextHandlerCollection + val gzipHandlers = handlers.map { h => + val gzipHandler = new GzipHandler + gzipHandler.setHandler(h) + gzipHandler + } + collection.setHandlers(gzipHandlers.toArray) + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(hostName, currentPort)) From c4830598b271cc6390d127bd4cf8ab02b28792e0 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Thu, 9 Jul 2015 13:26:46 -0700 Subject: [PATCH 089/149] [SPARK-6287] [MESOS] Add dynamic allocation to the coarse-grained Mesos scheduler This is largely based on extracting the dynamic allocation parts from tnachen's #3861. Author: Iulian Dragos Closes #4984 from dragos/issue/mesos-coarse-dynamicAllocation and squashes the following commits: 39df8cd [Iulian Dragos] Update tests to latest changes in core. 9d2c9fa [Iulian Dragos] Remove adjustment of executorLimitOption in doKillExecutors. 8b00f52 [Iulian Dragos] Latest round of reviews. 0cd00e0 [Iulian Dragos] Add persistent shuffle directory 15c45c1 [Iulian Dragos] Add dynamic allocation to the Spark coarse-grained scheduler. 
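For context, a hedged sketch of how an application might opt in once this lands, assuming the usual dynamic-allocation settings also apply on Mesos (the master URL and the exact set of required keys are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Illustrative configuration only; the master URL below is a placeholder.
val conf = new SparkConf()
  .setMaster("mesos://zk://zk-host:2181/mesos")
  .setAppName("dynamic-allocation-on-mesos")
  .set("spark.mesos.coarse", "true")               // this patch targets the coarse-grained backend
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.shuffle.service.enabled", "true")    // external shuffle service keeps shuffle files available after executors are removed
val sc = new SparkContext(conf)
```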
--- .../scala/org/apache/spark/SparkContext.scala | 19 +- .../mesos/CoarseMesosSchedulerBackend.scala | 136 +++++++++++--- .../cluster/mesos/MesosSchedulerUtils.scala | 4 +- .../spark/storage/DiskBlockManager.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 45 +++-- .../CoarseMesosSchedulerBackendSuite.scala | 175 ++++++++++++++++++ 6 files changed, 331 insertions(+), 56 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d2547eeff2b4e..82704b1ab2189 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -532,7 +532,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _executorAllocationManager = if (dynamicAllocationEnabled) { assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") + "Dynamic allocation of executors is currently only supported in YARN and Mesos mode") Some(new ExecutorAllocationManager(this, listenerBus, _conf)) } else { None @@ -853,7 +853,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions).setName(path) } - /** * :: Experimental :: * @@ -1364,10 +1363,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Return whether dynamically adjusting the amount of resources allocated to - * this application is supported. This is currently only available for YARN. + * this application is supported. This is currently only available for YARN + * and Mesos coarse-grained mode. 
*/ - private[spark] def supportDynamicAllocation = - master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) + private[spark] def supportDynamicAllocation: Boolean = { + (master.contains("yarn") + || master.contains("mesos") + || _conf.getBoolean("spark.dynamicAllocation.testing", false)) + } /** * :: DeveloperApi :: @@ -1385,7 +1388,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestTotalExecutors(numExecutors) @@ -1403,7 +1406,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { assert(supportDynamicAllocation, - "Requesting executors is currently only supported in YARN mode") + "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.requestExecutors(numAdditionalExecutors) @@ -1421,7 +1424,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { assert(supportDynamicAllocation, - "Killing executors is currently only supported in YARN mode") + "Killing executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => b.killExecutors(executorIds) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index b68f8c7685eba..cbade131494bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,11 +18,14 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList} +import java.util.{List => JList, Collections} +import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import com.google.common.collect.HashBiMap +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} @@ -60,9 +63,27 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed + val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] + // How many times tasks on each slave failed + val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + + /** + * The total number of executors we aim to have. Undefined when not using dynamic allocation + * and before the ExecutorAllocatorManager calls [[doRequesTotalExecutors]]. 
+ */ + private var executorLimitOption: Option[Int] = None + + /** + * Return the current executor limit, which may be [[Int.MaxValue]] + * before properly initialized. + */ + private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) + + private val pendingRemovedSlaveIds = new HashSet[String] + // private lock object protecting mutable state above. Using the intrinsic lock + // may lead to deadlocks since the superclass might also try to lock + private val stateLock = new ReentrantLock val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) @@ -86,7 +107,7 @@ private[spark] class CoarseMesosSchedulerBackend( startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } - def createCommand(offer: Offer, numCores: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -120,10 +141,6 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) @@ -133,7 +150,7 @@ private[spark] class CoarseMesosSchedulerBackend( command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + - s" --driver-url $driverUrl" + + s" --driver-url $driverURL" + s" --executor-id ${offer.getSlaveId.getValue}" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + @@ -142,11 +159,12 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head + val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + - s" --driver-url $driverUrl" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --driver-url $driverURL" + + s" --executor-id $executorId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -155,6 +173,17 @@ private[spark] class CoarseMesosSchedulerBackend( command.build() } + protected def driverURL: String = { + if (conf.contains("spark.testing")) { + "driverURL" + } else { + sc.env.rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + } + } + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { @@ -172,17 +201,18 @@ private[spark] class CoarseMesosSchedulerBackend( * unless we've already launched more than we wanted to. 
*/ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { + stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() for (offer <- offers) { val offerAttributes = toAttributeMap(offer.getAttributesList) val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) - val slaveId = offer.getSlaveId.toString + val slaveId = offer.getSlaveId.getValue val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue - if (meetsConstraints && + if (taskIdToSlaveId.size < executorLimit && totalCoresAcquired < maxCores && + meetsConstraints && mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && @@ -197,7 +227,7 @@ private[spark] class CoarseMesosSchedulerBackend( val task = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", calculateTotalMemory(sc))) @@ -209,7 +239,9 @@ private[spark] class CoarseMesosSchedulerBackend( // accept the offer and launch the task logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.launchTasks(List(offer.getId), List(task.build()), filters) + d.launchTasks( + Collections.singleton(offer.getId), + Collections.singleton(task.build()), filters) } else { // Decline the offer logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") @@ -224,7 +256,7 @@ private[spark] class CoarseMesosSchedulerBackend( val taskId = status.getTaskId.getValue.toInt val state = status.getState logInfo("Mesos task " + taskId + " is now " + state) - synchronized { + stateLock.synchronized { if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) slaveIdsWithExecutors -= slaveId @@ -242,8 +274,9 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } + executorTerminated(d, slaveId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node - mesosDriver.reviveOffers() + d.reviveOffers() } } } @@ -262,18 +295,39 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") + /** + * Called when a slave is lost or a Mesos task finished. Update local view on + * what tasks are running and remove the terminated slave from the list of pending + * slave IDs that we might have asked to be killed. It also notifies the driver + * that an executor was removed. 
+ */ + private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + stateLock.synchronized { + if (slaveIdsWithExecutors.contains(slaveId)) { + val slaveIdToTaskId = taskIdToSlaveId.inverse() + if (slaveIdToTaskId.contains(slaveId)) { + val taskId: Int = slaveIdToTaskId.get(slaveId) + taskIdToSlaveId.remove(taskId) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + } + // TODO: This assumes one Spark executor per Mesos slave, + // which may no longer be true after SPARK-5095 + pendingRemovedSlaveIds -= slaveId + slaveIdsWithExecutors -= slaveId } } } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { + private def sparkExecutorId(slaveId: String, taskId: String): String = { + s"$slaveId/$taskId" + } + + override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { + logInfo("Mesos slave lost: " + slaveId.getValue) + executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) + } + + override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } @@ -284,4 +338,34 @@ private[spark] class CoarseMesosSchedulerBackend( super.applicationId } + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + // We don't truly know if we can fulfill the full amount of executors + // since at coarse grain it depends on the amount of slaves available. + logInfo("Capping the total amount of executors to " + requestedTotal) + executorLimitOption = Some(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + if (mesosDriver == null) { + logWarning("Asked to kill executors before the Mesos driver was started.") + return false + } + + val slaveIdToTaskId = taskIdToSlaveId.inverse() + for (executorId <- executorIds) { + val slaveId = executorId.split("/")(0) + if (slaveIdToTaskId.contains(slaveId)) { + mesosDriver.killTask( + TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) + pendingRemovedSlaveIds += slaveId + } else { + logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + } + } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. 
+ // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d8a8c848bb4d1..925702e63afd3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConversions._ import scala.util.control.NonFatal import com.google.common.base.Splitter -import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler} +import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} import org.apache.mesos.Protos._ import org.apache.mesos.protobuf.GeneratedMessage import org.apache.spark.{Logging, SparkContext} @@ -39,7 +39,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { private final val registerLatch = new CountDownLatch(1) // Driver for talking to Mesos - protected var mesosDriver: MesosSchedulerDriver = null + protected var mesosDriver: SchedulerDriver = null /** * Starts the MesosSchedulerDriver with the provided information. This method returns diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 91ef86389a0c3..5f537692a16c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -124,10 +124,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon (blockId, getFile(blockId)) } + /** + * Create local directories for storing block data. These directories are + * located inside configured local directories and won't + * be deleted on JVM exit when using the external shuffle service. + */ private def createLocalDirs(conf: SparkConf): Array[File] = { - Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => + Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") + Utils.chmod700(localDir) logInfo(s"Created local directory at $localDir") Some(localDir) } catch { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 944560a91354a..b6b932104a94d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -733,7 +733,12 @@ private[spark] object Utils extends Logging { localRootDirs } - private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + /** + * Return the configured local directories where Spark can write files. This + * method does not create any directories on its own, it only encapsulates the + * logic of locating the local directories according to deployment mode. + */ + def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. 
Note this assumes that Yarn has @@ -749,27 +754,29 @@ private[spark] object Utils extends Logging { Option(conf.getenv("SPARK_LOCAL_DIRS")) .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) .split(",") - .flatMap { root => - try { - val rootDir = new File(root) - if (rootDir.exists || rootDir.mkdirs()) { - val dir = createTempDir(root) - chmod700(dir) - Some(dir.getAbsolutePath) - } else { - logError(s"Failed to create dir in $root. Ignoring this directory.") - None - } - } catch { - case e: IOException => - logError(s"Failed to create local root dir in $root. Ignoring this directory.") - None - } - } - .toArray } } + private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + getConfiguredLocalDirs(conf).flatMap { root => + try { + val rootDir = new File(root) + if (rootDir.exists || rootDir.mkdirs()) { + val dir = createTempDir(root) + chmod700(dir) + Some(dir.getAbsolutePath) + } else { + logError(s"Failed to create dir in $root. Ignoring this directory.") + None + } + } catch { + case e: IOException => + logError(s"Failed to create local root dir in $root. Ignoring this directory.") + None + } + }.toArray + } + /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(conf: SparkConf): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..3f1692917a357 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util +import java.util.Collections + +import org.apache.mesos.Protos.Value.Scalar +import org.apache.mesos.Protos._ +import org.apache.mesos.SchedulerDriver +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.Matchers +import org.scalatest.mock.MockitoSugar +import org.scalatest.BeforeAndAfter + +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} + +class CoarseMesosSchedulerBackendSuite extends SparkFunSuite + with LocalSparkContext + with MockitoSugar + with BeforeAndAfter { + + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(OfferID.newBuilder() + .setValue(offerId).build()) + .setFrameworkId(FrameworkID.newBuilder() + .setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) + .setHostname(s"host${slaveId}") + .build() + } + + private def createSchedulerBackend( + taskScheduler: TaskSchedulerImpl, + driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") { + mesosDriver = driver + markRegistered() + } + backend.start() + backend + } + + var sparkConf: SparkConf = _ + + before { + sparkConf = (new SparkConf) + .setMaster("local[*]") + .setAppName("test-mesos-dynamic-alloc") + .setSparkHome("/path") + + sc = new SparkContext(sparkConf) + } + + test("mesos supports killing and limiting executors") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc).toInt + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) + + val taskID0 = TaskID.newBuilder().setValue("0").build() + + backend.resourceOffers(driver, mesosOffers) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + // simulate the allocation manager down-scaling executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("s1/0"))) + verify(driver, times(1)).killTask(taskID0) + + val mesosOffers2 = new java.util.ArrayList[Offer] + mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) + backend.resourceOffers(driver, mesosOffers2) + + verify(driver, times(1)) + .declineOffer(OfferID.newBuilder().setValue("o2").build()) + + // Verify we didn't launch any new executor + assert(backend.slaveIdsWithExecutors.size === 1) + + backend.doRequestTotalExecutors(2) + backend.resourceOffers(driver, mesosOffers2) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), + any[util.Collection[TaskInfo]], + any[Filters]) + + assert(backend.slaveIdsWithExecutors.size === 2) + backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) + assert(backend.slaveIdsWithExecutors.size === 1) + } + + 
test("mesos supports killing and relaunching tasks with executors") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + + val backend = createSchedulerBackend(taskScheduler, driver) + val minMem = backend.calculateTotalMemory(sc).toInt + 1024 + val minCpu = 4 + + val mesosOffers = new java.util.ArrayList[Offer] + val offer1 = createOffer("o1", "s1", minMem, minCpu) + mesosOffers.add(offer1) + + val offer2 = createOffer("o2", "s1", minMem, 1); + + backend.resourceOffers(driver, mesosOffers) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer1.getId)), + anyObject(), + anyObject[Filters]) + + // Simulate task killed, executor no longer running + val status = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue("0").build()) + .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) + .setState(TaskState.TASK_KILLED) + .build + + backend.statusUpdate(driver, status) + assert(!backend.slaveIdsWithExecutors.contains("s1")) + + mesosOffers.clear() + mesosOffers.add(offer2) + backend.resourceOffers(driver, mesosOffers) + assert(backend.slaveIdsWithExecutors.contains("s1")) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer2.getId)), + anyObject(), + anyObject[Filters]) + + verify(driver, times(1)).reviveOffers() + } +} From 1f6b0b1234cc03aa2e07aea7fec2de7563885238 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 9 Jul 2015 13:48:29 -0700 Subject: [PATCH 090/149] [SPARK-8701] [STREAMING] [WEBUI] Add input metadata in the batch page This PR adds `metadata` to `InputInfo`. `InputDStream` can report its metadata for a batch and it will be shown in the batch page. For example, ![screen shot](https://cloud.githubusercontent.com/assets/1000778/8403741/d6ffc7e2-1e79-11e5-9888-c78c1575123a.png) FileInputDStream will display the new files for a batch, and DirectKafkaInputDStream will display its offset ranges. 
Author: zsxwing Closes #7081 from zsxwing/input-metadata and squashes the following commits: f7abd9b [zsxwing] Revert the space changes in project/MimaExcludes.scala d906209 [zsxwing] Merge branch 'master' into input-metadata 74762da [zsxwing] Fix MiMa tests 7903e33 [zsxwing] Merge branch 'master' into input-metadata 450a46c [zsxwing] Address comments 1d94582 [zsxwing] Raname InputInfo to StreamInputInfo and change "metadata" to Map[String, Any] d496ae9 [zsxwing] Add input metadata in the batch page --- .../kafka/DirectKafkaInputDStream.scala | 23 ++++++++-- .../spark/streaming/kafka/OffsetRange.scala | 2 +- project/MimaExcludes.scala | 6 +++ .../streaming/dstream/FileInputDStream.scala | 10 ++++- .../dstream/ReceiverInputDStream.scala | 4 +- .../spark/streaming/scheduler/BatchInfo.scala | 9 ++-- .../scheduler/InputInfoTracker.scala | 38 +++++++++++++--- .../streaming/scheduler/JobGenerator.scala | 3 +- .../spark/streaming/scheduler/JobSet.scala | 4 +- .../apache/spark/streaming/ui/BatchPage.scala | 43 +++++++++++++++++-- .../spark/streaming/ui/BatchUIData.scala | 8 ++-- .../ui/StreamingJobProgressListener.scala | 5 ++- .../streaming/StreamingListenerSuite.scala | 6 +-- .../spark/streaming/TestSuiteBase.scala | 2 +- .../scheduler/InputInfoTrackerSuite.scala | 8 ++-- .../StreamingJobProgressListenerSuite.scala | 28 ++++++------ 16 files changed, 148 insertions(+), 51 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 876456c964770..48a1933d92f85 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka import scala.annotation.tailrec import scala.collection.mutable -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.StreamInputInfo /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -119,8 +119,23 @@ class DirectKafkaInputDStream[ val rdd = KafkaRDD[K, V, U, T, R]( context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) - // Report the record number of this batch interval to InputInfoTracker. - val inputInfo = InputInfo(id, rdd.count) + // Report the record number and metadata of this batch interval to InputInfoTracker. + val offsetRanges = currentOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + } + val description = offsetRanges.filter { offsetRange => + // Don't display empty ranges. 
+ offsetRange.fromOffset != offsetRange.untilOffset + }.map { offsetRange => + s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + }.mkString("\n") + // Copy offsetRanges to immutable.List to prevent from being modified by the user + val metadata = Map( + "offsets" -> offsetRanges.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> description) + val inputInfo = StreamInputInfo(id, rdd.count, metadata) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 2675042666304..f326e7f1f6f8d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -75,7 +75,7 @@ final class OffsetRange private( } override def toString(): String = { - s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]" + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])" } /** this is to avoid ClassNotFoundException during checkpoint restore */ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 821aadd477ef3..79089aae2a37c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -77,6 +77,12 @@ object MimaExcludes { // SPARK-8914 Remove RDDApi ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.RDDApi") + ) ++ Seq( + // SPARK-8701 Add input metadata in the batch page. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo") ) case v if v.startsWith("1.4") => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 86a8e2beff57c..dd4da9d9ca6a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.StreamInputInfo import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} /** @@ -144,7 +145,14 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) batchTimeToSelectedFiles += ((validTime, newFiles)) recentlySelectedFiles ++= newFiles - Some(filesToRDD(newFiles)) + val rdds = Some(filesToRDD(newFiles)) + // Copy newFiles to immutable.List to prevent from being modified by the user + val metadata = Map( + "files" -> newFiles.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> newFiles.mkString("\n")) + val inputInfo = StreamInputInfo(id, 0, metadata) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + rdds } /** Clear the old time-to-files mappings along with old RDDs */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index e76e7eb0dea19..a50f0efc030ce 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -24,7 +24,7 @@ import org.apache.spark.storage.BlockId import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.InputInfo +import org.apache.spark.streaming.scheduler.StreamInputInfo import org.apache.spark.streaming.util.WriteAheadLogUtils /** @@ -70,7 +70,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray // Register the input blocks information into InputInfoTracker - val inputInfo = InputInfo(id, blockInfos.flatMap(_.numRecords).sum) + val inputInfo = StreamInputInfo(id, blockInfos.flatMap(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) if (blockInfos.nonEmpty) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 5b9bfbf9b01e3..9922b6bc1201b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -24,7 +24,7 @@ import org.apache.spark.streaming.Time * :: DeveloperApi :: * Class having information on completed batches. * @param batchTime Time of the batch - * @param streamIdToNumRecords A map of input stream id to record number + * @param streamIdToInputInfo A map of input stream id to its input info * @param submissionTime Clock time of when jobs of this batch was submitted to * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing @@ -33,12 +33,15 @@ import org.apache.spark.streaming.Time @DeveloperApi case class BatchInfo( batchTime: Time, - streamIdToNumRecords: Map[Int, Long], + streamIdToInputInfo: Map[Int, StreamInputInfo], submissionTime: Long, processingStartTime: Option[Long], processingEndTime: Option[Long] ) { + @deprecated("Use streamIdToInputInfo instead", "1.5.0") + def streamIdToNumRecords: Map[Int, Long] = streamIdToInputInfo.mapValues(_.numRecords) + /** * Time taken for the first job of this batch to start processing from the time this batch * was submitted to the streaming scheduler. Essentially, it is @@ -63,5 +66,5 @@ case class BatchInfo( /** * The number of recorders received by the receivers in this batch. 
*/ - def numRecords: Long = streamIdToNumRecords.values.sum + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 7c0db8a863c67..363c03d431f04 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -20,11 +20,34 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.streaming.{Time, StreamingContext} -/** To track the information of input stream at specified batch time. */ -private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { +/** + * :: DeveloperApi :: + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + */ +@DeveloperApi +case class StreamInputInfo( + inputStreamId: Int, numRecords: Long, metadata: Map[String, Any] = Map.empty) { require(numRecords >= 0, "numRecords must not be negative") + + def metadataDescription: Option[String] = + metadata.get(StreamInputInfo.METADATA_KEY_DESCRIPTION).map(_.toString) +} + +@DeveloperApi +object StreamInputInfo { + + /** + * The key for description in `StreamInputInfo.metadata`. + */ + val METADATA_KEY_DESCRIPTION: String = "Description" } /** @@ -34,12 +57,13 @@ private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging { // Map to track all the InputInfo related to specific batch time and input stream. 
- private val batchTimeToInputInfos = new mutable.HashMap[Time, mutable.HashMap[Int, InputInfo]] + private val batchTimeToInputInfos = + new mutable.HashMap[Time, mutable.HashMap[Int, StreamInputInfo]] /** Report the input information with batch time to the tracker */ - def reportInfo(batchTime: Time, inputInfo: InputInfo): Unit = synchronized { + def reportInfo(batchTime: Time, inputInfo: StreamInputInfo): Unit = synchronized { val inputInfos = batchTimeToInputInfos.getOrElseUpdate(batchTime, - new mutable.HashMap[Int, InputInfo]()) + new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId}} for batch" + @@ -49,10 +73,10 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging } /** Get the all the input stream's information of specified batch time */ - def getInfo(batchTime: Time): Map[Int, InputInfo] = synchronized { + def getInfo(batchTime: Time): Map[Int, StreamInputInfo] = synchronized { val inputInfos = batchTimeToInputInfos.get(batchTime) // Convert mutable HashMap to immutable Map for the caller - inputInfos.map(_.toMap).getOrElse(Map[Int, InputInfo]()) + inputInfos.map(_.toMap).getOrElse(Map[Int, StreamInputInfo]()) } /** Cleanup the tracked input information older than threshold batch time */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 9f93d6cbc3c20..f5d41858646e4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -244,8 +244,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } match { case Success(jobs) => val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time) - val streamIdToNumRecords = streamIdToInputInfos.mapValues(_.numRecords) - jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToNumRecords)) + jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index e6be63b2ddbdc..95833efc9417f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,7 +28,7 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - streamIdToNumRecords: Map[Int, Long] = Map.empty) { + streamIdToInputInfo: Map[Int, StreamInputInfo] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -64,7 +64,7 @@ case class JobSet( def toBatchInfo: BatchInfo = { new BatchInfo( time, - streamIdToNumRecords, + streamIdToInputInfo, submissionTime, if (processingStartTime >= 0 ) Some(processingStartTime) else None, if (processingEndTime >= 0 ) Some(processingEndTime) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index f75067669abe5..0c891662c264f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -17,11 +17,9 @@ package org.apache.spark.streaming.ui -import java.text.SimpleDateFormat -import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.{NodeSeq, Node, Text} +import scala.xml.{NodeSeq, Node, Text, Unparsed} import org.apache.commons.lang3.StringEscapeUtils @@ -303,6 +301,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { batchUIData.processingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") val formattedTotalDelay = batchUIData.totalDelay.map(SparkUIUtils.formatDuration).getOrElse("-") + val inputMetadatas = batchUIData.streamIdToInputInfo.values.flatMap { inputInfo => + inputInfo.metadataDescription.map(desc => inputInfo.inputStreamId -> desc) + }.toSeq val summary: NodeSeq =
      <div>
@@ -326,6 +327,13 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
           <li>
             <strong>Total delay: </strong>
             {formattedTotalDelay}
           </li>
+          {
+            if (inputMetadatas.nonEmpty) {
+              <li>
+                <strong>Input Metadata:</strong>{generateInputMetadataTable(inputMetadatas)}
+              </li>
+            }
+          }
         </ul>
       </div>
@@ -340,4 +348,33 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
     SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent)
   }
+
+  def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = {
+    <table>
+      <thead>
+        <tr>
+          <th>Input</th>
+          <th>Metadata</th>
+        </tr>
+      </thead>
+      <tbody>
+        {inputMetadatas.flatMap(generateInputMetadataRow)}
+      </tbody>
+    </table>
+  }
+
+  def generateInputMetadataRow(inputMetadata: (Int, String)): Seq[Node] = {
+    val streamId = inputMetadata._1
+
+    <tr>
+      <td>{streamingListener.streamName(streamId).getOrElse(s"Stream-$streamId")}</td>
+      <td>{metadataDescriptionToHTML(inputMetadata._2)}</td>
+    </tr>
+  }
+
+  private def metadataDescriptionToHTML(metadataDescription: String): Seq[Node] = {
+    // tab to 4 spaces and "\n" to "<br/>"
+    Unparsed(StringEscapeUtils.escapeHtml4(metadataDescription).
+      replaceAllLiterally("\t", "&nbsp;&nbsp;&nbsp;&nbsp;").replaceAllLiterally("\n", "<br/>
")) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index a5514dfd71c9f..ae508c0e9577b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -19,14 +19,14 @@ package org.apache.spark.streaming.ui import org.apache.spark.streaming.Time -import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.streaming.scheduler.{BatchInfo, StreamInputInfo} import org.apache.spark.streaming.ui.StreamingJobProgressListener._ private[ui] case class OutputOpIdAndSparkJobId(outputOpId: OutputOpId, sparkJobId: SparkJobId) private[ui] case class BatchUIData( val batchTime: Time, - val streamIdToNumRecords: Map[Int, Long], + val streamIdToInputInfo: Map[Int, StreamInputInfo], val submissionTime: Long, val processingStartTime: Option[Long], val processingEndTime: Option[Long], @@ -58,7 +58,7 @@ private[ui] case class BatchUIData( /** * The number of recorders received by the receivers in this batch. */ - def numRecords: Long = streamIdToNumRecords.values.sum + def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum } private[ui] object BatchUIData { @@ -66,7 +66,7 @@ private[ui] object BatchUIData { def apply(batchInfo: BatchInfo): BatchUIData = { new BatchUIData( batchInfo.batchTime, - batchInfo.streamIdToNumRecords, + batchInfo.streamIdToInputInfo, batchInfo.submissionTime, batchInfo.processingStartTime, batchInfo.processingEndTime diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 68e8ce98945e0..b77c555c68b8b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -192,7 +192,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) def receivedEventRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized { val _retainedBatches = retainedBatches val latestBatches = _retainedBatches.map { batchUIData => - (batchUIData.batchTime.milliseconds, batchUIData.streamIdToNumRecords) + (batchUIData.batchTime.milliseconds, batchUIData.streamIdToInputInfo.mapValues(_.numRecords)) } streamIds.map { streamId => val eventRates = latestBatches.map { @@ -205,7 +205,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def lastReceivedBatchRecords: Map[Int, Long] = synchronized { - val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.streamIdToNumRecords) + val lastReceivedBlockInfoOption = + lastReceivedBatch.map(_.streamIdToInputInfo.mapValues(_.numRecords)) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => streamIds.map { streamId => (streamId, lastReceivedBlockInfo.getOrElse(streamId, 0L)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 7bc7727a9fbe4..4bc1dd4a30fc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -59,7 +59,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosSubmitted.foreach { info => 
info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) @@ -77,7 +77,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosStarted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) @@ -98,7 +98,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { batchInfosCompleted.foreach { info => info.numRecords should be (1L) - info.streamIdToNumRecords should be (Map(0 -> 1L)) + info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 31b1aebf6a8ec..0d58a7b54412f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -76,7 +76,7 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], } // Report the input data's information to InputInfoTracker for testing - val inputInfo = InputInfo(id, selectedInput.length.toLong) + val inputInfo = StreamInputInfo(id, selectedInput.length.toLong) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index 2e210397fe7c7..f5248acf712b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -46,8 +46,8 @@ class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { val streamId1 = 0 val streamId2 = 1 val time = Time(0L) - val inputInfo1 = InputInfo(streamId1, 100L) - val inputInfo2 = InputInfo(streamId2, 300L) + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId2, 300L) inputInfoTracker.reportInfo(time, inputInfo1) inputInfoTracker.reportInfo(time, inputInfo2) @@ -63,8 +63,8 @@ class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { val inputInfoTracker = new InputInfoTracker(ssc) val streamId1 = 0 - val inputInfo1 = InputInfo(streamId1, 100L) - val inputInfo2 = InputInfo(streamId1, 300L) + val inputInfo1 = StreamInputInfo(streamId1, 100L) + val inputInfo2 = StreamInputInfo(streamId1, 300L) inputInfoTracker.reportInfo(Time(0), inputInfo1) inputInfoTracker.reportInfo(Time(1), inputInfo2) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index c9175d61b1f49..40dc1fb601bd0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -49,10 +49,12 @@ class StreamingJobProgressListenerSuite 
extends TestSuiteBase with Matchers { val ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -64,7 +66,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -94,7 +96,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoStarted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoStarted.processingDelay) batchUIData.get.totalDelay should be (batchInfoStarted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map(0 -> 300L, 1 -> 300L)) + batchUIData.get.streamIdToInputInfo should be (Map( + 0 -> StreamInputInfo(0, 300L), + 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test")))) batchUIData.get.numRecords should be(600) batchUIData.get.outputOpIdSparkJobIdPairs should be Seq(OutputOpIdAndSparkJobId(0, 0), @@ -103,7 +107,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) @@ -141,9 +145,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) @@ -182,7 +186,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.schedulingDelay should be (batchInfoSubmitted.schedulingDelay) batchUIData.get.processingDelay should be (batchInfoSubmitted.processingDelay) batchUIData.get.totalDelay should be (batchInfoSubmitted.totalDelay) - batchUIData.get.streamIdToNumRecords should be (Map.empty) + 
batchUIData.get.streamIdToInputInfo should be (Map.empty) batchUIData.get.numRecords should be (0) batchUIData.get.outputOpIdSparkJobIdPairs should be (Seq(OutputOpIdAndSparkJobId(0, 0))) @@ -211,14 +215,14 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) for (_ <- 0 until 2 * limit) { - val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) // onJobStart @@ -235,7 +239,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart4) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } From 3ccebf36c5abe04702d4cf223552a94034d980fb Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 9 Jul 2015 13:54:44 -0700 Subject: [PATCH 091/149] [SPARK-8389] [STREAMING] [PYSPARK] Expose KafkaRDDs offsetRange in Python This PR propose a simple way to expose OffsetRange in Python code, also the usage of offsetRanges is similar to Scala/Java way, here in Python we could get OffsetRange like: ``` dstream.foreachRDD(lambda r: KafkaUtils.offsetRanges(r)) ``` Reason I didn't follow the way what SPARK-8389 suggested is that: Python Kafka API has one more step to decode the message compared to Scala/Java, Which makes Python API return a transformed RDD/DStream, not directly wrapped so-called JavaKafkaRDD, so it is hard to backtrack to the original RDD to get the offsetRange. 
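[Editor's note] As the Python tests added in this patch show, the final API is called on the wrapped RDD itself via `rdd.offsetRanges()` inside `foreachRDD`/`transform`, rather than through a `KafkaUtils.offsetRanges(r)` helper as sketched above. A minimal usage sketch based on those tests; the StreamingContext `ssc`, the topic name, and the broker address are illustrative placeholders, not values from the patch:

```python
from pyspark.streaming.kafka import KafkaUtils

# Placeholder Kafka parameters for illustration only.
kafkaParams = {"metadata.broker.list": "localhost:9092",
               "auto.offset.reset": "smallest"}
stream = KafkaUtils.createDirectStream(ssc, ["test-topic"], kafkaParams)

offsetRanges = []

def collectOffsetRanges(_, rdd):
    # rdd is returned as the KafkaRDD wrapper added by this patch,
    # so offsetRanges() is available here.
    for o in rdd.offsetRanges():
        offsetRanges.append(o)

stream.foreachRDD(collectOffsetRanges)
```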
Author: jerryshao Closes #7185 from jerryshao/SPARK-8389 and squashes the following commits: 4c6d320 [jerryshao] Another way to fix subclass deserialization issue e6a8011 [jerryshao] Address the comments fd13937 [jerryshao] Fix serialization bug 7debf1c [jerryshao] bug fix cff3893 [jerryshao] refactor the code according to the comments 2aabf9e [jerryshao] Style fix 848c708 [jerryshao] Add HasOffsetRanges for Python --- .../spark/streaming/kafka/KafkaUtils.scala | 13 ++ python/pyspark/streaming/kafka.py | 123 ++++++++++++++++-- python/pyspark/streaming/tests.py | 64 +++++++++ python/pyspark/streaming/util.py | 7 +- 4 files changed, 196 insertions(+), 11 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0e33362d34acd..f3b01bd60b178 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -670,4 +670,17 @@ private class KafkaUtilsPythonHelper { TopicAndPartition(topic, partition) def createBroker(host: String, port: JInt): Broker = Broker(host, port) + + def offsetRangesOfKafkaRDD(rdd: RDD[_]): JList[OffsetRange] = { + val parentRDDs = rdd.getNarrowAncestors + val kafkaRDDs = parentRDDs.filter(rdd => rdd.isInstanceOf[KafkaRDD[_, _, _, _, _]]) + + require( + kafkaRDDs.length == 1, + "Cannot get offset ranges, as there may be multiple Kafka RDDs or no Kafka RDD associated" + + "with this RDD, please call this method only on a Kafka RDD.") + + val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] + kafkaRDD.offsetRanges.toSeq + } } diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 10a859a532e28..33dd596335b47 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -21,6 +21,8 @@ from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream +from pyspark.streaming.dstream import TransformedDStream +from pyspark.streaming.util import TransformFunction __all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] @@ -122,8 +124,9 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + stream = DStream(jstream, ssc, ser) \ + .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @staticmethod def createRDD(sc, kafkaParams, offsetRanges, leaders={}, @@ -161,8 +164,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - rdd = RDD(jrdd, sc, ser) - return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer) @staticmethod def _printErrorMsg(sc): @@ -200,14 +203,30 @@ def __init__(self, topic, partition, fromOffset, untilOffset): :param fromOffset: Inclusive starting offset. :param untilOffset: Exclusive ending offset. 
""" - self._topic = topic - self._partition = partition - self._fromOffset = fromOffset - self._untilOffset = untilOffset + self.topic = topic + self.partition = partition + self.fromOffset = fromOffset + self.untilOffset = untilOffset + + def __eq__(self, other): + if isinstance(other, self.__class__): + return (self.topic == other.topic + and self.partition == other.partition + and self.fromOffset == other.fromOffset + and self.untilOffset == other.untilOffset) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "OffsetRange(topic: %s, partition: %d, range: [%d -> %d]" \ + % (self.topic, self.partition, self.fromOffset, self.untilOffset) def _jOffsetRange(self, helper): - return helper.createOffsetRange(self._topic, self._partition, self._fromOffset, - self._untilOffset) + return helper.createOffsetRange(self.topic, self.partition, self.fromOffset, + self.untilOffset) class TopicAndPartition(object): @@ -244,3 +263,87 @@ def __init__(self, host, port): def _jBroker(self, helper): return helper.createBroker(self._host, self._port) + + +class KafkaRDD(RDD): + """ + A Python wrapper of KafkaRDD, to provide additional information on normal RDD. + """ + + def __init__(self, jrdd, ctx, jrdd_deserializer): + RDD.__init__(self, jrdd, ctx, jrdd_deserializer) + + def offsetRanges(self): + """ + Get the OffsetRange of specific KafkaRDD. + :return: A list of OffsetRange + """ + try: + helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd()) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(self.ctx) + raise e + + ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset()) + for o in joffsetRanges] + return ranges + + +class KafkaDStream(DStream): + """ + A Python wrapper of KafkaDStream + """ + + def __init__(self, jdstream, ssc, jrdd_deserializer): + DStream.__init__(self, jdstream, ssc, jrdd_deserializer) + + def foreachRDD(self, func): + """ + Apply a function to each RDD in this DStream. + """ + if func.__code__.co_argcount == 1: + old_func = func + func = lambda r, rdd: old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) \ + .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser)) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) + + def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream. + + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) + """ + if func.__code__.co_argcount == 1: + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.__code__.co_argcount == 2, "func should take one or two arguments" + + return KafkaTransformedDStream(self, func) + + +class KafkaTransformedDStream(TransformedDStream): + """ + Kafka specific wrapper of TransformedDStream to transform on Kafka RDD. 
+ """ + + def __init__(self, prev, func): + TransformedDStream.__init__(self, prev, func) + + @property + def _jdstream(self): + if self._jdstream_val is not None: + return self._jdstream_val + + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) \ + .rdd_wrapper(lambda jrdd, ctx, ser: KafkaRDD(jrdd, ctx, ser)) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 188c8ff12067e..4ecae1e4bf282 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -678,6 +678,70 @@ def test_kafka_rdd_with_leaders(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_get_offsetRanges(self): + """Test Python direct Kafka RDD get OffsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 3, "b": 4, "c": 5} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self.assertEqual(offsetRanges, rdd.offsetRanges()) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_foreach_get_offsetRanges(self): + """Test the Python direct Kafka stream foreachRDD get offsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + + offsetRanges = [] + + def getOffsetRanges(_, rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + + stream.foreachRDD(getOffsetRanges) + self.ssc.start() + self.wait_for(offsetRanges, 1) + + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_transform_get_offsetRanges(self): + """Test the Python direct Kafka stream transform get offsetRanges.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + + offsetRanges = [] + + def transformWithOffsetRanges(rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + return rdd + + stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.wait_for(offsetRanges, 1) + + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + class FlumeStreamTests(PySparkStreamingTestCase): timeout = 20 # seconds diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index a9bfec2aab8fc..b20613b1283bd 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -37,6 +37,11 @@ def __init__(self, ctx, func, 
*deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers + self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + + def rdd_wrapper(self, func): + self._rdd_wrapper = func + return self def call(self, milliseconds, jrdds): try: @@ -51,7 +56,7 @@ def call(self, milliseconds, jrdds): if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) From c9e2ef52bb54f35a904427389dc492d61f29b018 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 14:43:38 -0700 Subject: [PATCH 092/149] [SPARK-7902] [SPARK-6289] [SPARK-8685] [SQL] [PYSPARK] Refactor of serialization for Python DataFrame This PR fix the long standing issue of serialization between Python RDD and DataFrame, it change to using a customized Pickler for InternalRow to enable customized unpickling (type conversion, especially for UDT), now we can support UDT for UDF, cc mengxr . There is no generated `Row` anymore. Author: Davies Liu Closes #7301 from davies/sql_ser and squashes the following commits: 81bef71 [Davies Liu] address comments e9217bd [Davies Liu] add regression tests db34167 [Davies Liu] Refactor of serialization for Python DataFrame --- python/pyspark/sql/context.py | 5 +- python/pyspark/sql/dataframe.py | 16 +- python/pyspark/sql/tests.py | 28 +- python/pyspark/sql/types.py | 419 ++++++------------ .../spark/sql/catalyst/expressions/rows.scala | 12 + .../org/apache/spark/sql/DataFrame.scala | 5 +- .../spark/sql/execution/pythonUDFs.scala | 122 ++++- 7 files changed, 292 insertions(+), 315 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 309c11faf9319..c93a15badae29 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -30,7 +30,7 @@ from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.sql import since from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ - _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter + _infer_schema, _has_nulltype, _merge_type, _create_converter from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.utils import install_exception_handler @@ -388,8 +388,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): raise TypeError("schema should be StructType or list or None") # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) + rdd = rdd.map(schema.toInternal) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1e9c657cf81b3..83e02b85f06f1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -31,7 +31,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql import since -from pyspark.sql.types import _create_cls, _parse_datatype_json_string +from pyspark.sql.types import _parse_datatype_json_string from pyspark.sql.column import Column, _to_seq, _to_java_column from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * @@ -83,15 +83,7 @@ def rdd(self): 
""" if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() - rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) - schema = self.schema - - def applySchema(it): - cls = _create_cls(schema) - return map(cls, it) - - self._lazy_rdd = rdd.mapPartitions(applySchema) - + self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) return self._lazy_rdd @property @@ -287,9 +279,7 @@ def collect(self): """ with SCCallSiteSync(self._sc) as css: port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) - rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) - cls = _create_cls(self.schema) - return [cls(r) for r in rs] + return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(1.3) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 66827d48850d9..4d7cad5a1ab88 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -151,6 +151,17 @@ def test_range(self): self.assertEqual(self.sqlCtx.range(-2).count(), 0) self.assertEqual(self.sqlCtx.range(3).count(), 3) + def test_duplicated_column_names(self): + df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"]) + row = df.select('*').first() + self.assertEqual(1, row[0]) + self.assertEqual(2, row[1]) + self.assertEqual("Row(c=1, c=2)", str(row)) + # Cannot access columns + self.assertRaises(AnalysisException, lambda: df.select(df[0]).first()) + self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) + self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) + def test_explode(self): from pyspark.sql.functions import explode d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] @@ -401,6 +412,14 @@ def test_apply_schema_with_udt(self): point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + def test_udf_with_udt(self): + from pyspark.sql.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.sc.parallelize([row]).toDF() + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + def test_parquet_with_udt(self): from pyspark.sql.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) @@ -693,12 +712,9 @@ def test_time_with_timezone(self): utcnow = datetime.datetime.fromtimestamp(ts, utc) df = self.sqlCtx.createDataFrame([(day, now, utcnow)]) day1, now1, utcnow1 = df.first() - # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version - self.assertEqual(day1.date(), day) - # Pyrolite does not support microsecond, the error should be - # less than 1 millisecond - self.assertTrue(now - now1 < datetime.timedelta(0.001)) - self.assertTrue(now - utcnow1 < datetime.timedelta(0.001)) + self.assertEqual(day1, day) + self.assertEqual(now, now1) + self.assertEqual(now, utcnow1) def test_decimal(self): from decimal import Decimal diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index fecfe6d71e9a7..d63857691675a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -20,13 +20,9 @@ import time import datetime import calendar -import keyword -import warnings import json import re -import weakref from array import array -from operator import itemgetter if sys.version >= "3": long = int @@ -71,6 +67,26 @@ def json(self): separators=(',', ':'), sort_keys=True) + def needConversion(self): + """ + Does this 
type need to conversion between Python object and internal SQL object. + + This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. + """ + return False + + def toInternal(self, obj): + """ + Converts a Python object into an internal SQL object. + """ + return obj + + def fromInternal(self, obj): + """ + Converts an internal SQL object into a native Python object. + """ + return obj + # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle @@ -143,6 +159,17 @@ class DateType(AtomicType): __metaclass__ = DataTypeSingleton + EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() + + def needConversion(self): + return True + + def toInternal(self, d): + return d and d.toordinal() - self.EPOCH_ORDINAL + + def fromInternal(self, v): + return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL) + class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. @@ -150,6 +177,19 @@ class TimestampType(AtomicType): __metaclass__ = DataTypeSingleton + def needConversion(self): + return True + + def toInternal(self, dt): + if dt is not None: + seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo + else time.mktime(dt.timetuple())) + return int(seconds * 1e6 + dt.microsecond) + + def fromInternal(self, ts): + if ts is not None: + return datetime.datetime.fromtimestamp(ts / 1e6) + class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. @@ -259,6 +299,19 @@ def fromJson(cls, json): return ArrayType(_parse_datatype_json_value(json["elementType"]), json["containsNull"]) + def needConversion(self): + return self.elementType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.toInternal(v) for v in obj] + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.fromInternal(v) for v in obj] + class MapType(DataType): """Map data type. @@ -304,6 +357,21 @@ def fromJson(cls, json): _parse_datatype_json_value(json["valueType"]), json["valueContainsNull"]) + def needConversion(self): + return self.keyType.needConversion() or self.valueType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) + for k, v in obj.items()) + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) + for k, v in obj.items()) + class StructField(DataType): """A field in :class:`StructType`. @@ -311,7 +379,7 @@ class StructField(DataType): :param name: string, name of the field. :param dataType: :class:`DataType` of the field. :param nullable: boolean, whether the field can be null (None) or not. - :param metadata: a dict from string to simple type that can be serialized to JSON automatically + :param metadata: a dict from string to simple type that can be toInternald to JSON automatically """ def __init__(self, name, dataType, nullable=True, metadata=None): @@ -351,6 +419,15 @@ def fromJson(cls, json): json["nullable"], json["metadata"]) + def needConversion(self): + return self.dataType.needConversion() + + def toInternal(self, obj): + return self.dataType.toInternal(obj) + + def fromInternal(self, obj): + return self.dataType.fromInternal(obj) + class StructType(DataType): """Struct type, consisting of a list of :class:`StructField`. 
@@ -371,10 +448,13 @@ def __init__(self, fields=None): """ if not fields: self.fields = [] + self.names = [] else: self.fields = fields + self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" + self._needSerializeFields = None def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -406,6 +486,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): """ if isinstance(field, StructField): self.fields.append(field) + self.names.append(field.name) else: if isinstance(field, str) and data_type is None: raise ValueError("Must specify DataType if passing name of struct_field to create.") @@ -415,6 +496,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): else: data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) + self.names.append(field) return self def simpleString(self): @@ -432,6 +514,41 @@ def jsonValue(self): def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) + def needConversion(self): + # We need convert Row()/namedtuple into tuple() + return True + + def toInternal(self, obj): + if obj is None: + return + + if self._needSerializeFields is None: + self._needSerializeFields = any(f.needConversion() for f in self.fields) + + if self._needSerializeFields: + if isinstance(obj, dict): + return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields)) + elif isinstance(obj, (tuple, list)): + return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + else: + if isinstance(obj, dict): + return tuple(obj.get(n) for n in self.names) + elif isinstance(obj, (list, tuple)): + return tuple(obj) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + + def fromInternal(self, obj): + if obj is None: + return + if isinstance(obj, Row): + # it's already converted by pickler + return obj + values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)] + return _create_row(self.names, values) + class UserDefinedType(DataType): """User-defined type (UDT). @@ -464,17 +581,35 @@ def scalaUDT(cls): """ raise NotImplementedError("UDT must have a paired Scala UDT.") + def needConversion(self): + return True + + @classmethod + def _cachedSqlType(cls): + """ + Cache the sqlType() into class, because it's heavy used in `toInternal`. + """ + if not hasattr(cls, "_cached_sql_type"): + cls._cached_sql_type = cls.sqlType() + return cls._cached_sql_type + + def toInternal(self, obj): + return self._cachedSqlType().toInternal(self.serialize(obj)) + + def fromInternal(self, obj): + return self.deserialize(self._cachedSqlType().fromInternal(obj)) + def serialize(self, obj): """ Converts the a user-type object into a SQL datum. """ - raise NotImplementedError("UDT must implement serialize().") + raise NotImplementedError("UDT must implement toInternal().") def deserialize(self, datum): """ Converts a SQL datum into a user-type object. """ - raise NotImplementedError("UDT must implement deserialize().") + raise NotImplementedError("UDT must implement fromInternal().") def simpleString(self): return 'udt' @@ -671,117 +806,6 @@ def _infer_schema(row): return StructType(fields) -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. 
- - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - True - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - # convert namedtuple or Row into tuple - return True - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - elif isinstance(dataType, (DateType, TimestampType)): - return True - else: - return False - - -EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - if any(_need_python_to_sql_conversion(t) for t in types): - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - elif obj is not None: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - else: - def converter(obj): - if isinstance(obj, dict): - return tuple(obj.get(n) for n in names) - else: - return tuple(obj) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: a and [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - - elif isinstance(dataType, UserDefinedType): - return lambda obj: obj and dataType.serialize(obj) - - elif isinstance(dataType, DateType): - return lambda d: d and d.toordinal() - EPOCH_ORDINAL - - elif isinstance(dataType, TimestampType): - - def to_posix_timstamp(dt): - if dt: - seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo - else time.mktime(dt.timetuple())) - return 
int(seconds * 1e6 + dt.microsecond) - return to_posix_timstamp - - else: - raise ValueError("Unexpected type %r" % dataType) - - def _has_nulltype(dt): """ Return whether there is NullType in `dt` or not """ if isinstance(dt, StructType): @@ -1076,7 +1100,7 @@ def _verify_type(obj, dataType): if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) + _verify_type(dataType.toInternal(obj), dataType.sqlType()) return _type = type(dataType) @@ -1086,7 +1110,7 @@ def _verify_type(obj, dataType): if not isinstance(obj, (tuple, list)): raise TypeError("StructType can not accept object in type %s" % type(obj)) else: - # subclass of them can not be deserialized in JVM + # subclass of them can not be fromInternald in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) @@ -1106,159 +1130,10 @@ def _verify_type(obj, dataType): for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) -_cached_cls = weakref.WeakValueDictionary() - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None or cls.__datatype is not dataType: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - cls.__datatype = dataType - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. - if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. 
- - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for atomic types - return lambda x: x - - class Row(tuple): - - """ Row in DataFrame """ - __datatype = dataType - __fields__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__fields__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__fields__)) - - def __reduce__(self): - return (_restore_object, (self.__datatype, tuple(self))) - return Row +# This is used to unpickle a Row from JVM +def _create_row_inbound_converter(dataType): + return lambda *a: dataType.fromInternal(a) def _create_row(fields, values): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8b472a529e5c9..094904bbf9c15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -132,6 +132,18 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) override def copy(): InternalRow = this } +/** + * This is used for serialization of Python DataFrame + */ +class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType) + extends GenericInternalRow(values) { + + /** No-arg constructor for serialization. */ + protected def this() = this(null, null) + + override def fieldIndex(name: String): Int = schema.fieldIndex(name) +} + class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { /** No-arg constructor for serialization. 
*/ protected def this() = this(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d9f987ae0252f..d7966651b1948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -30,7 +30,6 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ @@ -1550,8 +1549,8 @@ class DataFrame private[sql]( */ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { val structType = schema // capture it for closure - val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) + val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) + EvaluatePython.javaToPython(rdd) } //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 1c8130b07c7fb..6d6e67dace177 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql.execution +import java.io.OutputStream import java.util.{List => JList, Map => JMap} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -import net.razorvine.pickle.{Pickler, Unpickler} +import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} +import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -33,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Accumulator, Logging => SparkLogging} @@ -130,8 +130,13 @@ object EvaluatePython { case (null, _) => null case (row: InternalRow, struct: StructType) => - val fields = struct.fields.map(field => field.dataType) - rowToArray(row, fields) + val values = new Array[Any](row.size) + var i = 0 + while (i < row.size) { + values(i) = toJava(row(i), struct.fields(i).dataType) + i += 1 + } + new GenericInternalRowWithSchema(values, struct) case (seq: Seq[Any], array: ArrayType) => seq.map(x => toJava(x, array.elementType)).asJava @@ -142,9 +147,6 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) - case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) - case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) - case (d: Decimal, _) => d.toJavaBigDecimal case (s: UTF8String, StringType) => s.toString @@ -152,14 +154,6 @@ object EvaluatePython { case (other, _) => other } - /** - * Convert Row into Java Array (for pickled into Python) - */ - def rowToArray(row: 
InternalRow, fields: Seq[DataType]): Array[Any] = { - // TODO: this is slow! - row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray - } - /** * Converts `obj` to the type specified by the data type, or returns null if the type of obj is * unexpected. Because Python doesn't enforce the type. @@ -220,6 +214,96 @@ object EvaluatePython { // TODO(davies): we could improve this by try to cast the object to expected type case (c, _) => null } + + + private val module = "pyspark.sql.types" + + /** + * Pickler for StructType + */ + private class StructTypePickler extends IObjectPickler { + + private val cls = classOf[StructType] + + def register(): Unit = { + Pickler.registerCustomPickler(cls, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + out.write(Opcodes.GLOBAL) + out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8")) + val schema = obj.asInstanceOf[StructType] + pickler.save(schema.json) + out.write(Opcodes.TUPLE1) + out.write(Opcodes.REDUCE) + } + } + + /** + * Pickler for InternalRow + */ + private class RowPickler extends IObjectPickler { + + private val cls = classOf[GenericInternalRowWithSchema] + + // register this to Pickler and Unpickler + def register(): Unit = { + Pickler.registerCustomPickler(this.getClass, this) + Pickler.registerCustomPickler(cls, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + if (obj == this) { + out.write(Opcodes.GLOBAL) + out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8")) + } else { + // it will be memorized by Pickler to save some bytes + pickler.save(this) + val row = obj.asInstanceOf[GenericInternalRowWithSchema] + // schema should always be same object for memoization + pickler.save(row.schema) + out.write(Opcodes.TUPLE1) + out.write(Opcodes.REDUCE) + + out.write(Opcodes.MARK) + var i = 0 + while (i < row.values.size) { + pickler.save(row.values(i)) + i += 1 + } + row.values.foreach(pickler.save) + out.write(Opcodes.TUPLE) + out.write(Opcodes.REDUCE) + } + } + } + + private[this] var registered = false + /** + * This should be called before trying to serialize any above classes un cluster mode, + * this should be put in the closure + */ + def registerPicklers(): Unit = { + synchronized { + if (!registered) { + SerDeUtil.initialize() + new StructTypePickler().register() + new RowPickler().register() + registered = true + } + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. 
+ */ + def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = { + rdd.mapPartitions { iter => + registerPicklers() // let it called in executor + new SerDeUtil.AutoBatchedPickler(iter) + } + } } /** @@ -254,12 +338,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val childResults = child.execute().map(_.copy()) val parent = childResults.mapPartitions { iter => + EvaluatePython.registerPicklers() // register pickler for Row val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) - iter.grouped(1000).map { inputRows => + val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) + iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => - EvaluatePython.rowToArray(currentRow(row), fields) + EvaluatePython.toJava(currentRow(row), schema) }.toArray pickle.dumps(toBePickled) } From 897700369f3aedf1a8fdb0984dd3d6d8e498e3af Mon Sep 17 00:00:00 2001 From: guowei2 Date: Thu, 9 Jul 2015 15:01:53 -0700 Subject: [PATCH 093/149] [SPARK-8865] [STREAMING] FIX BUG: check key in kafka params Author: guowei2 Closes #7254 from guowei2/spark-8865 and squashes the following commits: 48ca17a [guowei2] fix contains key --- .../scala/org/apache/spark/streaming/kafka/KafkaCluster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 3e6b937af57b0..8465432c5850f 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -410,7 +410,7 @@ object KafkaCluster { } Seq("zookeeper.connect", "group.id").foreach { s => - if (!props.contains(s)) { + if (!props.containsKey(s)) { props.setProperty(s, "") } } From 69165330303a71ea1da748eca7a780ec172b326f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 9 Jul 2015 15:14:14 -0700 Subject: [PATCH 094/149] Closes #6837 Closes #7321 Closes #2634 Closes #4963 Closes #2137 From e29ce319fa6ffb9c8e5110814d4923d433aa1b76 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 9 Jul 2015 15:49:30 -0700 Subject: [PATCH 095/149] [SPARK-8963][ML] cleanup tests in linear regression suite Simplify model weight assertions to use vector comparision, switch to using absTol when comparing with 0.0 intercepts Author: Holden Karau Closes #7327 from holdenk/SPARK-8913-cleanup-tests-from-SPARK-8700-logistic-regression and squashes the following commits: 5bac185 [Holden Karau] Simplify model weight assertions to use vector comparision, switch to using absTol when comparing with 0.0 intercepts --- .../ml/regression/LinearRegressionSuite.scala | 57 ++++++++----------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 5f39d44f37352..4f6a57739558b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import 
org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -75,11 +75,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 7.198257 */ val interceptR = 6.298698 - val weightsR = Array(4.700706, 7.199082) + val weightsR = Vectors.dense(4.700706, 7.199082) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -104,11 +103,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V2. 6.995908 as.numeric.data.V3. 5.275131 */ - val weightsR = Array(6.995908, 5.275131) + val weightsR = Vectors.dense(6.995908, 5.275131) - assert(model.intercept ~== 0 relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== 0 absTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) /* Then again with the data with no intercept: > weightsWithoutIntercept @@ -118,11 +116,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data3.V2. 4.70011 as.numeric.data3.V3. 7.19943 */ - val weightsWithoutInterceptR = Array(4.70011, 7.19943) + val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) - assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3) - assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3) - assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3) + assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3) + assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3) } test("linear regression with intercept with L1 regularization") { @@ -139,11 +136,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 6.679841 */ val interceptR = 6.24300 - val weightsR = Array(4.024821, 6.679841) + val weightsR = Vectors.dense(4.024821, 6.679841) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -169,11 +165,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 4.772913 */ val interceptR = 0.0 - val weightsR = Array(6.299752, 4.772913) + val weightsR = Vectors.dense(6.299752, 4.772913) - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-5) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -197,11 +192,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 
4.926260 */ val interceptR = 5.269376 - val weightsR = Array(3.736216, 5.712356) + val weightsR = Vectors.dense(3.736216, 5.712356) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -227,11 +221,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 4.214502 */ val interceptR = 0.0 - val weightsR = Array(5.522875, 4.214502) + val weightsR = Vectors.dense(5.522875, 4.214502) - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + assert(model.weights ~== weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -255,11 +248,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 5.200403 */ val interceptR = 5.696056 - val weightsR = Array(3.670489, 6.001122) + val weightsR = Vectors.dense(3.670489, 6.001122) assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.weights ~== weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -285,11 +277,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.dataM.V3. 4.322251 */ val interceptR = 0.0 - val weightsR = Array(5.673348, 4.322251) + val weightsR = Vectors.dense(5.673348, 4.322251) - assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + assert(model.weights ~= weightsR relTol 1E-3) model.transform(dataset).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => From a0cc3e5aa3fcfd0fce6813c520152657d327aaf2 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 9 Jul 2015 16:21:21 -0700 Subject: [PATCH 096/149] [SPARK-8538] [SPARK-8539] [ML] Linear Regression Training and Testing Results Adds results (e.g. objective value at each iteration, residuals) on training and user-specified test sets for LinearRegressionModel. Notes to Reviewers: * Are the `*TrainingResults` and `Results` classes too specialized for `LinearRegressionModel`? Where would be an appropriate level of abstraction? * Please check `transient` annotations are correct; the datasets should not be copied and kept during serialization. * Any thoughts on `RDD`s versus `DataFrame`s? If using `DataFrame`s, suggested schemas for each intermediate step? Also, how to create a "local DataFrame" without a `sqlContext`? 
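[Editor's note] A minimal sketch of how the training summary introduced by this patch is meant to be consumed. It assumes an existing `DataFrame` named `training` with label and feature columns; that variable is illustrative and not part of the patch, and only fields defined in the diff below are used:

```scala
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.DataFrame

// `training` is an assumed, pre-existing DataFrame of (label, features) rows.
def inspectTraining(training: DataFrame): Unit = {
  val model = new LinearRegression().fit(training)

  // Diagnostics added by this patch: the per-iteration objective trace and
  // regression metrics computed on the training data.
  val summary = model.summary
  println(s"iterations:        ${summary.totalIterations}")
  println(s"objective history: ${summary.objectiveHistory.mkString(", ")}")
  println(s"training RMSE:     ${summary.rootMeanSquaredError}")
  println(s"training MSE:      ${summary.meanSquaredError}")
}
```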
Author: Feynman Liang Closes #7099 from feynmanliang/SPARK-8538 and squashes the following commits: d219fa4 [Feynman Liang] Update docs 4a42680 [Feynman Liang] Change Summary to hold values, move transient annotations down to metrics and predictions DF 6300031 [Feynman Liang] Code review changes 0a5e762 [Feynman Liang] Fix build error e71102d [Feynman Liang] Merge branch 'master' into SPARK-8538 3367489 [Feynman Liang] Merge branch 'master' into SPARK-8538 70f267c [Feynman Liang] Make TrainingSummary transient and remove Serializable from *Summary and RegressionMetrics 1d9ea42 [Feynman Liang] Fix failing Java test a65dfda [Feynman Liang] Make TrainingSummary and metrics serializable, prediction dataframe transient 0a605d8 [Feynman Liang] Replace Params from LinearRegression*Summary with private constructor vals c2fe835 [Feynman Liang] Optimize imports 02d8a70 [Feynman Liang] Add Params to LinearModel*Summary, refactor tests and add test for evaluate() 8f999f4 [Feynman Liang] Refactor from jkbradley code review 072e948 [Feynman Liang] Style 509ae36 [Feynman Liang] Use DFs and localize serialization to LinearRegressionModel 9509c79 [Feynman Liang] Fix imports b2bbaa3 [Feynman Liang] Refactored LinearRegressionResults API to be more private ffceaec [Feynman Liang] Merge branch 'master' into SPARK-8538 1cedb2b [Feynman Liang] Add test for decreasing objective trace dab0aff [Feynman Liang] Add LinearRegressionTrainingResults tests, make test suite code copy+pasteable 97b0a81 [Feynman Liang] Add LinearRegressionModel.evaluate() to get results on test sets dc51bce [Feynman Liang] Style guide fixes 521f397 [Feynman Liang] Use RDD[(Double, Double)] instead of DF 2ff5710 [Feynman Liang] Add training results and model summary to ML LinearRegression --- .../ml/regression/LinearRegression.scala | 139 +++++++++++++++++- .../ml/regression/LinearRegressionSuite.scala | 59 ++++++++ 2 files changed, 192 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f672c96576a33..8fc986056657d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,18 +22,20 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV, norm => brzNorm} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter @@ -139,7 +141,16 @@ class LinearRegression(override val uid: String) logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + s"and the 
intercept will be the mean of the label; as a result, training is not needed.") if (handlePersistence) instances.unpersist() - return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean) + val weights = Vectors.sparse(numFeatures, Seq()) + val intercept = yMean + + val model = new LinearRegressionModel(uid, weights, intercept) + val trainingSummary = new LinearRegressionTrainingSummary( + model.transform(dataset).select($(predictionCol), $(labelCol)), + $(predictionCol), + $(labelCol), + Array(0D)) + return copyValues(model.setSummary(trainingSummary)) } val featuresMean = summarizer.mean.toArray @@ -178,7 +189,6 @@ class LinearRegression(override val uid: String) state = states.next() arrayBuilder += state.adjustedValue } - if (state == null) { val msg = s"${optimizer.getClass.getName} failed." logError(msg) @@ -209,7 +219,13 @@ class LinearRegression(override val uid: String) if (handlePersistence) instances.unpersist() - copyValues(new LinearRegressionModel(uid, weights, intercept)) + val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) + val trainingSummary = new LinearRegressionTrainingSummary( + model.transform(dataset).select($(predictionCol), $(labelCol)), + $(predictionCol), + $(labelCol), + objectiveHistory) + model.setSummary(trainingSummary) } override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) @@ -227,13 +243,124 @@ class LinearRegressionModel private[ml] ( extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams { + private var trainingSummary: Option[LinearRegressionTrainingSummary] = None + + /** + * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + def summary: LinearRegressionTrainingSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this LinearRegressionModel", + new NullPointerException()) + } + + private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** Indicates whether a training summary exists for this model instance. */ + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Evaluates the model on a testset. + * @param dataset Test dataset to evaluate model on. + */ + // TODO: decide on a good name before exposing to public API + private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + val t = udf { features: Vector => predict(features) } + val predictionAndObservations = dataset + .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol))) + + new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol)) + } + override protected def predict(features: Vector): Double = { dot(features, weights) + intercept } override def copy(extra: ParamMap): LinearRegressionModel = { - copyValues(new LinearRegressionModel(uid, weights, intercept), extra) + val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) + if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) + newModel + } +} + +/** + * :: Experimental :: + * Linear regression training results. + * @param predictions predictions outputted by the model's `transform` method. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. 
+ */ +@Experimental +class LinearRegressionTrainingSummary private[regression] ( + predictions: DataFrame, + predictionCol: String, + labelCol: String, + val objectiveHistory: Array[Double]) + extends LinearRegressionSummary(predictions, predictionCol, labelCol) { + + /** Number of training iterations until termination */ + val totalIterations = objectiveHistory.length + +} + +/** + * :: Experimental :: + * Linear regression results evaluated on a dataset. + * @param predictions predictions outputted by the model's `transform` method. + */ +@Experimental +class LinearRegressionSummary private[regression] ( + @transient val predictions: DataFrame, + val predictionCol: String, + val labelCol: String) extends Serializable { + + @transient private val metrics = new RegressionMetrics( + predictions + .select(predictionCol, labelCol) + .map { case Row(pred: Double, label: Double) => (pred, label) } ) + + /** + * Returns the explained variance regression score. + * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + */ + val explainedVariance: Double = metrics.explainedVariance + + /** + * Returns the mean absolute error, which is a risk function corresponding to the + * expected value of the absolute error loss or l1-norm loss. + */ + val meanAbsoluteError: Double = metrics.meanAbsoluteError + + /** + * Returns the mean squared error, which is a risk function corresponding to the + * expected value of the squared error loss or quadratic loss. + */ + val meanSquaredError: Double = metrics.meanSquaredError + + /** + * Returns the root mean squared error, which is defined as the square root of + * the mean squared error. + */ + val rootMeanSquaredError: Double = metrics.rootMeanSquaredError + + /** + * Returns R^2^, the coefficient of determination. 
+ * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + */ + val r2: Double = metrics.r2 + + /** Residuals (predicted value - label value) */ + @transient lazy val residuals: DataFrame = { + val t = udf { (pred: Double, label: Double) => pred - label} + predictions.select(t(col(predictionCol), col(labelCol)).as("residuals")) } + } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 4f6a57739558b..cf120cf2a4b47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -289,4 +289,63 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(prediction1 ~== prediction2 relTol 1E-5) } } + + test("linear regression model training summary") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + // Training results for the model should be available + assert(model.hasSummary) + + // Residuals in [[LinearRegressionResults]] should equal those manually computed + val expectedResiduals = dataset.select("features", "label") + .map { case Row(features: DenseVector, label: Double) => + val prediction = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + prediction - label + } + .zip(model.summary.residuals.map(_.getDouble(0))) + .collect() + .foreach { case (manualResidual: Double, resultResidual: Double) => + assert(manualResidual ~== resultResidual relTol 1E-5) + } + + /* + Use the following R code to generate model training results. + + predictions <- predict(fit, newx=features) + residuals <- predictions - label + > mean(residuals^2) # MSE + [1] 0.009720325 + > mean(abs(residuals)) # MAD + [1] 0.07863206 + > cor(predictions, label)^2# r^2 + [,1] + s0 0.9998749 + */ + assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) + assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) + assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) + + // Objective function should be monotonically decreasing for linear regression + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } + + test("linear regression model testset evaluation summary") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + // Evaluating on training dataset should yield results summary equal to training summary + val testSummary = model.evaluate(dataset) + assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5) + assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5) + model.summary.residuals.select("residuals").collect() + .zip(testSummary.residuals.select("residuals").collect()) + .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } + } + } From 2d45571fcb002cc9f03056c5a3f14493b83315a4 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 9 Jul 2015 17:09:16 -0700 Subject: [PATCH 097/149] [SPARK-8959] [SQL] [HOTFIX] Removes parquet-thrift and libthrift dependencies These two dependencies were introduced in #7231 to help testing Parquet compatibility with `parquet-thrift`. However, they somehow crash the Scala compiler in Maven builds. This PR fixes this issue by: 1. Removing these two dependencies, and 2. 
Instead of generating the testing Parquet file programmatically, checking in an actual testing Parquet file generated by `parquet-thrift` as a test resource. This is just a quick fix to bring back Maven builds. Need to figure out the root case as binary Parquet files are harder to maintain. Author: Cheng Lian Closes #7330 from liancheng/spark-8959 and squashes the following commits: cf69512 [Cheng Lian] Brings back Maven builds --- pom.xml | 14 - sql/core/pom.xml | 10 - .../spark/sql/parquet/test/thrift/Nested.java | 541 ---- .../test/thrift/ParquetThriftCompat.java | 2808 ----------------- .../spark/sql/parquet/test/thrift/Suit.java | 51 - .../parquet-thrift-compat.snappy.parquet | Bin 0 -> 10550 bytes .../ParquetThriftCompatibilitySuite.scala | 78 +- 7 files changed, 8 insertions(+), 3494 deletions(-) delete mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Nested.java delete mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/ParquetThriftCompat.java delete mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/thrift/Suit.java create mode 100755 sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet diff --git a/pom.xml b/pom.xml index 529e47f8b5253..1eda108dc065b 100644 --- a/pom.xml +++ b/pom.xml @@ -161,7 +161,6 @@ 2.4.4 1.1.1.7 1.1.2 - 0.9.2 false @@ -181,7 +180,6 @@ compile compile test - test + + commons-codec + commons-codec + provided + + + commons-net + commons-net + provided + + + com.google.protobuf + protobuf-java + provided + org.apache.avro avro - ${avro.version} + provided org.apache.avro avro-ipc - ${avro.version} - - - io.netty - netty - - - org.mortbay.jetty - jetty - - - org.mortbay.jetty - jetty-util - - - org.mortbay.jetty - servlet-api - - - org.apache.velocity - velocity - - + provided + + + org.scala-lang + scala-library + provided - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - reference.conf - - - log4j.properties - - - - - - - - - - + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + + + flume-provided + + provided + + + diff --git a/pom.xml b/pom.xml index 1eda108dc065b..172fdef4c73da 100644 --- a/pom.xml +++ b/pom.xml @@ -1130,6 +1130,10 @@ io.netty netty + + org.apache.flume + flume-ng-auth + org.apache.thrift libthrift From 2727304660663fcf1e41f7b666978c1443262e4e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 9 Jul 2015 19:08:33 -0700 Subject: [PATCH 099/149] [SPARK-8913] [ML] Simplify LogisticRegression suite to use Vector Vector comparision Cleanup tests from SPARK 8700. 
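The suite changes that follow replace per-element weight assertions with a single Vector-level comparison through the suite's TestingUtils `~=`/`~==` operators. As a hedged, self-contained sketch of the same idea (the helper and the numbers are illustrative, not the suite's utilities or its R-generated values):

import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Illustrative stand-in for the suite's TestingUtils `~=` operator, so the sketch is
// self-contained; it checks every element within a relative tolerance in one call.
def approxEqual(a: Vector, b: Vector, relTol: Double): Boolean =
  a.size == b.size && a.toArray.zip(b.toArray).forall { case (x, y) =>
    math.abs(x - y) <= relTol * math.max(math.abs(x), math.abs(y))
  }

// Hypothetical fitted and expected coefficients.
val fitted   = Vectors.dense(-0.58957, 0.89312, -0.39251, -0.79969)
val expected = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864)

// Old style: assert(fitted(0) ~== expected(0) relTol 1E-3), repeated for each index.
// New style: one assertion over the whole vector.
assert(approxEqual(fitted, expected, relTol = 1e-3))

Comparing whole vectors keeps each test case to one expected-value literal and one assertion, which is why the suite gets shorter without changing what is verified.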
Author: Holden Karau Closes #7335 from holdenk/SPARK-8913-cleanup-tests-from-SPARK-8700-logistic-regression-r2-really-logistic-regression-this-time and squashes the following commits: e5e2c5f [Holden Karau] Simplify LogisticRegression suite to use Vector <-> Vector comparisions instead of comparing element by element --- .../LogisticRegressionSuite.scala | 135 +++++------------- 1 file changed, 39 insertions(+), 96 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 27253c1db2fff..b7dd44753896a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -234,20 +234,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.7996864 */ val interceptR = 2.8366423 - val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864) + val weightsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model1.weights(1) ~== weightsR(1) relTol 1E-3) - assert(model1.weights(2) ~== weightsR(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR(3) relTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-3) // Without regularization, with or without standardization will converge to the same solution. assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights(0) ~== weightsR(0) relTol 1E-3) - assert(model2.weights(1) ~== weightsR(1) relTol 1E-3) - assert(model2.weights(2) ~== weightsR(2) relTol 1E-3) - assert(model2.weights(3) ~== weightsR(3) relTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-3) } test("binary logistic regression without intercept without regularization") { @@ -277,20 +271,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.7407946 */ val interceptR = 0.0 - val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946) + val weightsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights(0) ~== weightsR(0) relTol 1E-2) - assert(model1.weights(1) ~== weightsR(1) relTol 1E-2) - assert(model1.weights(2) ~== weightsR(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR(3) relTol 1E-3) + assert(model1.weights ~= weightsR relTol 1E-2) // Without regularization, with or without standardization should converge to the same solution. 
assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights(0) ~== weightsR(0) relTol 1E-2) - assert(model2.weights(1) ~== weightsR(1) relTol 1E-2) - assert(model2.weights(2) ~== weightsR(2) relTol 1E-3) - assert(model2.weights(3) ~== weightsR(3) relTol 1E-3) + assert(model2.weights ~= weightsR relTol 1E-2) } test("binary logistic regression with intercept with L1 regularization") { @@ -321,13 +309,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.02481551 */ val interceptR1 = -0.05627428 - val weightsR1 = Array(0.0, 0.0, -0.04325749, -0.02481551) + val weightsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.weights(0) ~== weightsR1(0) absTol 1E-3) - assert(model1.weights(1) ~== weightsR1(1) absTol 1E-3) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-2) - assert(model1.weights(3) ~== weightsR1(3) relTol 2E-2) + assert(model1.weights ~= weightsR1 absTol 2E-2) /* Using the following R code to load the data and train the model using glmnet package. @@ -349,13 +334,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR2 = 0.3722152 - val weightsR2 = Array(0.0, 0.0, -0.1665453, 0.0) + val weightsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) assert(model2.intercept ~== interceptR2 relTol 1E-2) - assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) absTol 1E-3) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-2) - assert(model2.weights(3) ~== weightsR2(3) absTol 1E-3) + assert(model2.weights ~= weightsR2 absTol 1E-3) } test("binary logistic regression without intercept with L1 regularization") { @@ -387,13 +369,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.03891782 */ val interceptR1 = 0.0 - val weightsR1 = Array(0.0, 0.0, -0.05189203, -0.03891782) + val weightsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights(0) ~== weightsR1(0) absTol 1E-3) - assert(model1.weights(1) ~== weightsR1(1) absTol 1E-3) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-2) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-2) + assert(model1.weights ~= weightsR1 absTol 1E-3) /* Using the following R code to load the data and train the model using glmnet package. @@ -415,13 +394,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . 
*/ val interceptR2 = 0.0 - val weightsR2 = Array(0.0, 0.0, -0.08420782, 0.0) + val weightsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) absTol 1E-3) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-2) - assert(model2.weights(3) ~== weightsR2(3) absTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 absTol 1E-3) } test("binary logistic regression with intercept with L2 regularization") { @@ -452,13 +428,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.10062872 */ val interceptR1 = 0.15021751 - val weightsR1 = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872) + val weightsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights(0) ~== weightsR1(0) relTol 1E-3) - assert(model1.weights(1) ~== weightsR1(1) relTol 1E-3) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-3) /* Using the following R code to load the data and train the model using glmnet package. @@ -480,13 +453,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.06266838 */ val interceptR2 = 0.48657516 - val weightsR2 = Array(-0.05155371, 0.02301057, -0.11482896, -0.06266838) + val weightsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights(0) ~== weightsR2(0) relTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) relTol 1E-3) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-3) - assert(model2.weights(3) ~== weightsR2(3) relTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-3) } test("binary logistic regression without intercept with L2 regularization") { @@ -518,13 +488,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.09799775 */ val interceptR1 = 0.0 - val weightsR1 = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775) + val weightsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights(0) ~== weightsR1(0) relTol 1E-2) - assert(model1.weights(1) ~== weightsR1(1) relTol 1E-2) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= weightsR1 relTol 1E-2) /* Using the following R code to load the data and train the model using glmnet package. 
@@ -546,13 +513,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.053314311 */ val interceptR2 = 0.0 - val weightsR2 = Array(-0.005679651, 0.048967094, -0.093714016, -0.053314311) + val weightsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights(0) ~== weightsR2(0) relTol 1E-2) - assert(model2.weights(1) ~== weightsR2(1) relTol 1E-2) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-3) - assert(model2.weights(3) ~== weightsR2(3) relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 relTol 1E-2) } test("binary logistic regression with intercept with ElasticNet regularization") { @@ -583,13 +547,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.15458796 */ val interceptR1 = 0.57734851 - val weightsR1 = Array(-0.05310287, 0.0, -0.08849250, -0.15458796) + val weightsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796) assert(model1.intercept ~== interceptR1 relTol 6E-3) - assert(model1.weights(0) ~== weightsR1(0) relTol 5E-3) - assert(model1.weights(1) ~== weightsR1(1) absTol 1E-3) - assert(model1.weights(2) ~== weightsR1(2) relTol 5E-3) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-3) + assert(model1.weights ~== weightsR1 absTol 5E-3) /* Using the following R code to load the data and train the model using glmnet package. @@ -611,13 +572,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.05350074 */ val interceptR2 = 0.51555993 - val weightsR2 = Array(0.0, 0.0, -0.18807395, -0.05350074) + val weightsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074) assert(model2.intercept ~== interceptR2 relTol 6E-3) - assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) absTol 1E-3) - assert(model2.weights(2) ~== weightsR2(2) relTol 5E-3) - assert(model2.weights(3) ~== weightsR2(3) relTol 1E-2) + assert(model2.weights ~= weightsR2 absTol 1E-3) } test("binary logistic regression without intercept with ElasticNet regularization") { @@ -649,13 +607,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.142534158 */ val interceptR1 = 0.0 - val weightsR1 = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158) + val weightsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights(0) ~== weightsR1(0) absTol 1E-2) - assert(model1.weights(1) ~== weightsR1(1) absTol 1E-2) - assert(model1.weights(2) ~== weightsR1(2) relTol 1E-3) - assert(model1.weights(3) ~== weightsR1(3) relTol 1E-2) + assert(model1.weights ~= weightsR1 absTol 1E-2) /* Using the following R code to load the data and train the model using glmnet package. @@ -677,13 +632,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . 
*/ val interceptR2 = 0.0 - val weightsR2 = Array(0.0, 0.03345223, -0.11304532, 0.0) + val weightsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights(0) ~== weightsR2(0) absTol 1E-3) - assert(model2.weights(1) ~== weightsR2(1) relTol 1E-2) - assert(model2.weights(2) ~== weightsR2(2) relTol 1E-2) - assert(model2.weights(3) ~== weightsR2(3) absTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= weightsR2 absTol 1E-3) } test("binary logistic regression with intercept with strong L1 regularization") { @@ -717,19 +669,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { }}} */ val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) - val weightsTheory = Array(0.0, 0.0, 0.0, 0.0) + val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptTheory relTol 1E-5) - assert(model1.weights(0) ~== weightsTheory(0) absTol 1E-6) - assert(model1.weights(1) ~== weightsTheory(1) absTol 1E-6) - assert(model1.weights(2) ~== weightsTheory(2) absTol 1E-6) - assert(model1.weights(3) ~== weightsTheory(3) absTol 1E-6) + assert(model1.weights ~= weightsTheory absTol 1E-6) assert(model2.intercept ~== interceptTheory relTol 1E-5) - assert(model2.weights(0) ~== weightsTheory(0) absTol 1E-6) - assert(model2.weights(1) ~== weightsTheory(1) absTol 1E-6) - assert(model2.weights(2) ~== weightsTheory(2) absTol 1E-6) - assert(model2.weights(3) ~== weightsTheory(3) absTol 1E-6) + assert(model2.weights ~= weightsTheory absTol 1E-6) /* Using the following R code to load the data and train the model using glmnet package. @@ -750,12 +696,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR = -0.248065 - val weightsR = Array(0.0, 0.0, 0.0, 0.0) + val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptR relTol 1E-5) - assert(model1.weights(0) ~== weightsR(0) absTol 1E-6) - assert(model1.weights(1) ~== weightsR(1) absTol 1E-6) - assert(model1.weights(2) ~== weightsR(2) absTol 1E-6) - assert(model1.weights(3) ~== weightsR(3) absTol 1E-6) + assert(model1.weights ~= weightsR absTol 1E-6) } } From 1903641e68ce7e7e657584bf45e91db6df357e41 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Thu, 9 Jul 2015 19:31:31 -0700 Subject: [PATCH 100/149] [SPARK-8839] [SQL] ThriftServer2 will remove session and execution no matter it's finished or not. In my test, `sessions` and `executions` in ThriftServer2 is not the same number as the connection number. For example, if there are 200 clients connecting to the server, but it will have more than 200 `sessions` and `executions`. So if it reaches the `retainedStatements`, it has to remove some object which is not finished. So it may cause the exception described in [Jira Address](https://issues.apache.org/jira/browse/SPARK-8839) Author: huangzhaowei Closes #7239 from SaintBacchus/SPARK-8839 and squashes the following commits: cf7ef40 [huangzhaowei] Remove the a meanless funciton call 3e9a5a6 [huangzhaowei] Add a filter before take 9d5ceb8 [huangzhaowei] [SPARK-8839][SQL]ThriftServer2 will remove session and execution no matter it's finished or not. 
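A minimal sketch of the rule the listener change below introduces, using hypothetical stand-in types rather than the real ExecutionInfo/SessionInfo bookkeeping: once the retained limit is exceeded, only entries whose finishTimestamp has been set are eligible for eviction, so an in-flight statement or session is never removed and then mutated after removal.

import scala.collection.mutable

object ListenerSketch {
  // Hypothetical stand-in for the listener's per-statement record.
  case class ExecInfo(statement: String, var finishTimestamp: Long = 0L)

  val retainedStatements = 200
  val executionList = new mutable.LinkedHashMap[String, ExecInfo]()

  def trimExecutionIfNecessary(): Unit = synchronized {
    if (executionList.size > retainedStatements) {
      val toRemove = math.max(retainedStatements / 10, 1)
      // Evict only finished entries; unfinished ones stay regardless of age.
      executionList.filter { case (_, info) => info.finishTimestamp != 0 }
        .take(toRemove)
        .foreach { case (id, _) => executionList.remove(id) }
    }
  }
}

The same filter is applied to the session list, and the patch also calls the trim helpers when a session closes or a statement fails or finishes, so trimming happens as entries reach a terminal state.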
--- .../spark/sql/hive/thriftserver/HiveThriftServer2.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 700d994bb6a83..b7db80d93f852 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -179,6 +179,7 @@ object HiveThriftServer2 extends Logging { def onSessionClosed(sessionId: String): Unit = { sessionList(sessionId).finishTimestamp = System.currentTimeMillis onlineSessionNum -= 1 + trimSessionIfNecessary() } def onStatementStart( @@ -206,18 +207,20 @@ object HiveThriftServer2 extends Logging { executionList(id).detail = errorMessage executionList(id).state = ExecutionState.FAILED totalRunning -= 1 + trimExecutionIfNecessary() } def onStatementFinish(id: String): Unit = { executionList(id).finishTimestamp = System.currentTimeMillis executionList(id).state = ExecutionState.FINISHED totalRunning -= 1 + trimExecutionIfNecessary() } private def trimExecutionIfNecessary() = synchronized { if (executionList.size > retainedStatements) { val toRemove = math.max(retainedStatements / 10, 1) - executionList.take(toRemove).foreach { s => + executionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => executionList.remove(s._1) } } @@ -226,7 +229,7 @@ object HiveThriftServer2 extends Logging { private def trimSessionIfNecessary() = synchronized { if (sessionList.size > retainedSessions) { val toRemove = math.max(retainedSessions / 10, 1) - sessionList.take(toRemove).foreach { s => + sessionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => sessionList.remove(s._1) } } From d538919cc4fd3ab940d478c62dce1bae0270cfeb Mon Sep 17 00:00:00 2001 From: Michael Vogiatzis Date: Thu, 9 Jul 2015 19:53:23 -0700 Subject: [PATCH 101/149] [DOCS] Added important updateStateByKey details Runs for *all* existing keys and returning "None" will remove the key-value pair. Author: Michael Vogiatzis Closes #7229 from mvogiatzis/patch-1 and squashes the following commits: e7a2946 [Michael Vogiatzis] Updated updateStateByKey text 00283ed [Michael Vogiatzis] Removed space c2656f9 [Michael Vogiatzis] Moved description farther up 0a42551 [Michael Vogiatzis] Added important updateStateByKey details --- docs/streaming-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index e72d5580dae55..2f3013b533eb0 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -854,6 +854,8 @@ it with new information. To use this, you will have to do two steps. 1. Define the state update function - Specify with a function how to update the state using the previous state and the new values from an input stream. +In every batch, Spark will apply the state update function for all existing keys, regardless of whether they have new data in a batch or not. If the update function returns `None` then the key-value pair will be eliminated. + Let's illustrate this with an example. Say you want to maintain a running count of each word seen in a text data stream. Here, the running count is the state and it is an integer. 
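As a hedged illustration of the two behaviours just described, the sketch below keeps a (count, idleBatches) pair as state and drops a key once it has gone three batches without data; the idle-expiry rule is made up for the example, and the guide's own running-count function follows.

// State is (runningCount, idleBatches). The function is invoked for every key seen so
// far, with an empty newValues when the key had no data in this batch; returning None
// removes the key-value pair from the state.
def updateFunction(newValues: Seq[Int], state: Option[(Int, Int)]): Option[(Int, Int)] = {
  val (count, idle) = state.getOrElse((0, 0))
  if (newValues.nonEmpty) Some((count + newValues.sum, 0)) // new data: update count, reset idle
  else if (idle + 1 < 3) Some((count, idle + 1))           // no data this batch: keep the state
  else None                                                // idle for three batches: drop the key
}

// Applied as: val stateCounts = pairDStream.updateStateByKey[(Int, Int)](updateFunction _)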
We define the update function as: From e14b545d2dcbc4587688b4c46718d3680b0a2f67 Mon Sep 17 00:00:00 2001 From: Jonathan Alter Date: Fri, 10 Jul 2015 11:34:01 +0100 Subject: [PATCH 102/149] [SPARK-7977] [BUILD] Disallowing println Author: Jonathan Alter Closes #7093 from jonalter/SPARK-7977 and squashes the following commits: ccd44cc [Jonathan Alter] Changed println to log in ThreadingSuite 7fcac3e [Jonathan Alter] Reverting to println in ThreadingSuite 10724b6 [Jonathan Alter] Changing some printlns to logs in tests eeec1e7 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 0b1dcb4 [Jonathan Alter] More println cleanup aedaf80 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 925fd98 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 0c16fa3 [Jonathan Alter] Replacing some printlns with logs 45c7e05 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 5c8e283 [Jonathan Alter] Allowing println in audit-release examples 5b50da1 [Jonathan Alter] Allowing printlns in example files ca4b477 [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 83ab635 [Jonathan Alter] Fixing new printlns 54b131f [Jonathan Alter] Merge branch 'master' of github.com:apache/spark into SPARK-7977 1cd8a81 [Jonathan Alter] Removing some unnecessary comments and printlns b837c3a [Jonathan Alter] Disallowing println --- .../main/scala/org/apache/spark/Logging.scala | 2 ++ .../org/apache/spark/api/r/RBackend.scala | 2 ++ .../scala/org/apache/spark/api/r/RRDD.scala | 2 ++ .../org/apache/spark/deploy/Client.scala | 30 ++++++++-------- .../apache/spark/deploy/ClientArguments.scala | 4 +++ .../org/apache/spark/deploy/RRunner.scala | 2 ++ .../org/apache/spark/deploy/SparkSubmit.scala | 18 ++++++++++ .../spark/deploy/SparkSubmitArguments.scala | 4 +++ .../spark/deploy/client/TestExecutor.scala | 2 ++ .../history/HistoryServerArguments.scala | 2 ++ .../spark/deploy/master/MasterArguments.scala | 2 ++ .../MesosClusterDispatcherArguments.scala | 6 ++++ .../spark/deploy/worker/DriverWrapper.scala | 2 ++ .../spark/deploy/worker/WorkerArguments.scala | 4 +++ .../CoarseGrainedExecutorBackend.scala | 4 +++ .../input/FixedLengthBinaryInputFormat.scala | 7 ++-- .../spark/network/nio/BlockMessage.scala | 22 ------------ .../spark/network/nio/BlockMessageArray.scala | 34 ++++--------------- .../spark/network/nio/ConnectionManager.scala | 4 +++ .../scala/org/apache/spark/rdd/PipedRDD.scala | 4 +++ .../scheduler/EventLoggingListener.scala | 2 ++ .../apache/spark/scheduler/JobLogger.scala | 2 ++ .../org/apache/spark/ui/JettyUtils.scala | 2 ++ .../apache/spark/ui/UIWorkloadGenerator.scala | 6 +++- .../org/apache/spark/util/Distribution.scala | 6 ++++ .../spark/util/random/XORShiftRandom.scala | 2 ++ .../org/apache/spark/DistributedSuite.scala | 2 ++ .../scala/org/apache/spark/FailureSuite.scala | 2 ++ .../org/apache/spark/FileServerSuite.scala | 2 ++ .../org/apache/spark/ThreadingSuite.scala | 6 ++-- .../spark/deploy/SparkSubmitSuite.scala | 4 +++ .../spark/deploy/SparkSubmitUtilsSuite.scala | 2 ++ .../WholeTextFileRecordReaderSuite.scala | 8 ++--- .../metrics/InputOutputMetricsSuite.scala | 2 ++ .../spark/scheduler/ReplayListenerSuite.scala | 2 ++ .../spark/util/ClosureCleanerSuite.scala | 2 ++ .../org/apache/spark/util/UtilsSuite.scala | 2 ++ .../util/collection/SizeTrackerSuite.scala | 4 +++ .../spark/util/collection/SorterSuite.scala | 10 +++--- .../src/main/scala/SparkApp.scala | 2 ++ 
.../src/main/scala/SparkApp.scala | 2 ++ .../src/main/scala/GraphxApp.scala | 2 ++ .../sbt_app_hive/src/main/scala/HiveApp.scala | 2 ++ .../src/main/scala/SparkApp.scala | 2 ++ .../sbt_app_sql/src/main/scala/SqlApp.scala | 2 ++ .../src/main/scala/StreamingApp.scala | 2 ++ .../apache/spark/examples/BroadcastTest.scala | 2 ++ .../spark/examples/CassandraCQLTest.scala | 2 ++ .../apache/spark/examples/CassandraTest.scala | 2 ++ .../spark/examples/DFSReadWriteTest.scala | 2 ++ .../spark/examples/DriverSubmissionTest.scala | 2 ++ .../apache/spark/examples/GroupByTest.scala | 2 ++ .../org/apache/spark/examples/HBaseTest.scala | 2 ++ .../org/apache/spark/examples/HdfsTest.scala | 2 ++ .../org/apache/spark/examples/LocalALS.scala | 2 ++ .../apache/spark/examples/LocalFileLR.scala | 2 ++ .../apache/spark/examples/LocalKMeans.scala | 2 ++ .../org/apache/spark/examples/LocalLR.scala | 2 ++ .../org/apache/spark/examples/LocalPi.scala | 2 ++ .../org/apache/spark/examples/LogQuery.scala | 2 ++ .../spark/examples/MultiBroadcastTest.scala | 2 ++ .../examples/SimpleSkewedGroupByTest.scala | 2 ++ .../spark/examples/SkewedGroupByTest.scala | 2 ++ .../org/apache/spark/examples/SparkALS.scala | 2 ++ .../apache/spark/examples/SparkHdfsLR.scala | 2 ++ .../apache/spark/examples/SparkKMeans.scala | 2 ++ .../org/apache/spark/examples/SparkLR.scala | 2 ++ .../apache/spark/examples/SparkPageRank.scala | 2 ++ .../org/apache/spark/examples/SparkPi.scala | 2 ++ .../org/apache/spark/examples/SparkTC.scala | 2 ++ .../spark/examples/SparkTachyonHdfsLR.scala | 2 ++ .../spark/examples/SparkTachyonPi.scala | 2 ++ .../spark/examples/graphx/Analytics.scala | 2 ++ .../examples/graphx/LiveJournalPageRank.scala | 2 ++ .../examples/graphx/SynthBenchmark.scala | 2 ++ .../examples/ml/CrossValidatorExample.scala | 2 ++ .../examples/ml/DecisionTreeExample.scala | 2 ++ .../examples/ml/DeveloperApiExample.scala | 2 ++ .../apache/spark/examples/ml/GBTExample.scala | 2 ++ .../examples/ml/LinearRegressionExample.scala | 2 ++ .../ml/LogisticRegressionExample.scala | 2 ++ .../spark/examples/ml/MovieLensALS.scala | 2 ++ .../spark/examples/ml/OneVsRestExample.scala | 2 ++ .../examples/ml/RandomForestExample.scala | 2 ++ .../examples/ml/SimpleParamsExample.scala | 2 ++ .../ml/SimpleTextClassificationPipeline.scala | 2 ++ .../examples/mllib/BinaryClassification.scala | 2 ++ .../spark/examples/mllib/Correlations.scala | 2 ++ .../examples/mllib/CosineSimilarity.scala | 2 ++ .../spark/examples/mllib/DatasetExample.scala | 2 ++ .../examples/mllib/DecisionTreeRunner.scala | 2 ++ .../examples/mllib/DenseGaussianMixture.scala | 2 ++ .../spark/examples/mllib/DenseKMeans.scala | 2 ++ .../examples/mllib/FPGrowthExample.scala | 2 ++ .../mllib/GradientBoostedTreesRunner.scala | 2 ++ .../spark/examples/mllib/LDAExample.scala | 2 ++ .../examples/mllib/LinearRegression.scala | 2 ++ .../spark/examples/mllib/MovieLensALS.scala | 2 ++ .../mllib/MultivariateSummarizer.scala | 2 ++ .../PowerIterationClusteringExample.scala | 3 +- .../examples/mllib/RandomRDDGeneration.scala | 2 ++ .../spark/examples/mllib/SampledRDDs.scala | 2 ++ .../examples/mllib/SparseNaiveBayes.scala | 2 ++ .../mllib/StreamingKMeansExample.scala | 2 ++ .../mllib/StreamingLinearRegression.scala | 2 ++ .../mllib/StreamingLogisticRegression.scala | 2 ++ .../spark/examples/mllib/TallSkinnyPCA.scala | 2 ++ .../spark/examples/mllib/TallSkinnySVD.scala | 2 ++ .../spark/examples/sql/RDDRelation.scala | 2 ++ .../examples/sql/hive/HiveFromSpark.scala | 2 ++ .../examples/streaming/ActorWordCount.scala | 2 
++ .../examples/streaming/CustomReceiver.scala | 2 ++ .../streaming/DirectKafkaWordCount.scala | 2 ++ .../examples/streaming/FlumeEventCount.scala | 2 ++ .../streaming/FlumePollingEventCount.scala | 2 ++ .../examples/streaming/HdfsWordCount.scala | 2 ++ .../examples/streaming/KafkaWordCount.scala | 2 ++ .../examples/streaming/MQTTWordCount.scala | 4 +++ .../examples/streaming/NetworkWordCount.scala | 2 ++ .../examples/streaming/RawNetworkGrep.scala | 2 ++ .../RecoverableNetworkWordCount.scala | 2 ++ .../streaming/SqlNetworkWordCount.scala | 2 ++ .../streaming/StatefulNetworkWordCount.scala | 2 ++ .../streaming/TwitterAlgebirdCMS.scala | 2 ++ .../streaming/TwitterAlgebirdHLL.scala | 2 ++ .../streaming/TwitterPopularTags.scala | 2 ++ .../examples/streaming/ZeroMQWordCount.scala | 2 ++ .../clickstream/PageViewGenerator.scala | 2 ++ .../clickstream/PageViewStream.scala | 2 ++ .../kafka/DirectKafkaStreamSuite.scala | 2 +- .../streaming/KinesisWordCountASL.scala | 2 ++ .../spark/graphx/util/BytecodeUtils.scala | 1 - .../spark/graphx/util/GraphGenerators.scala | 4 +-- .../graphx/util/BytecodeUtilsSuite.scala | 2 ++ .../mllib/util/KMeansDataGenerator.scala | 2 ++ .../mllib/util/LinearDataGenerator.scala | 2 ++ .../LogisticRegressionDataGenerator.scala | 2 ++ .../spark/mllib/util/MFDataGenerator.scala | 2 ++ .../spark/mllib/util/SVMDataGenerator.scala | 2 ++ .../spark/ml/feature/VectorIndexerSuite.scala | 10 +++--- .../spark/mllib/linalg/VectorsSuite.scala | 6 ++-- .../spark/mllib/stat/CorrelationSuite.scala | 6 ++-- .../tree/GradientBoostedTreesSuite.scala | 10 +++--- .../spark/mllib/util/NumericParserSuite.scala | 2 +- project/SparkBuild.scala | 4 +++ .../apache/spark/repl/SparkCommandLine.scala | 2 ++ .../org/apache/spark/repl/SparkILoop.scala | 2 ++ .../apache/spark/repl/SparkILoopInit.scala | 2 ++ .../org/apache/spark/repl/SparkIMain.scala | 2 ++ .../org/apache/spark/repl/SparkILoop.scala | 2 ++ .../org/apache/spark/repl/SparkIMain.scala | 4 +++ .../apache/spark/repl/SparkReplReporter.scala | 2 ++ scalastyle-config.xml | 12 +++---- .../expressions/codegen/package.scala | 2 ++ .../spark/sql/catalyst/plans/QueryPlan.scala | 2 ++ .../spark/sql/catalyst/util/package.scala | 2 ++ .../apache/spark/sql/types/StructType.scala | 2 ++ .../scala/org/apache/spark/sql/Column.scala | 2 ++ .../org/apache/spark/sql/DataFrame.scala | 6 ++++ .../spark/sql/execution/debug/package.scala | 16 ++++----- .../hive/thriftserver/SparkSQLCLIDriver.scala | 12 ++++--- .../apache/spark/sql/hive/HiveContext.scala | 5 +-- .../org/apache/spark/sql/hive/HiveQl.scala | 5 +-- .../spark/sql/hive/client/ClientWrapper.scala | 2 ++ .../regression-test-SPARK-8489/Main.scala | 2 ++ .../sql/hive/HiveMetastoreCatalogSuite.scala | 6 ++-- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 ++ .../sql/hive/InsertIntoHiveTableSuite.scala | 2 -- .../sql/hive/MetastoreDataSourcesSuite.scala | 6 ++-- .../sql/hive/execution/HiveUDFSuite.scala | 1 - .../spark/streaming/dstream/DStream.scala | 2 ++ .../spark/streaming/util/RawTextSender.scala | 2 ++ .../spark/streaming/util/RecurringTimer.scala | 4 +-- .../spark/streaming/MasterFailureTest.scala | 4 +++ .../scheduler/JobGeneratorSuite.scala | 1 - .../spark/tools/GenerateMIMAIgnore.scala | 8 +++++ .../tools/JavaAPICompletenessChecker.scala | 4 +++ .../spark/tools/StoragePerfTester.scala | 4 +++ .../yarn/ApplicationMasterArguments.scala | 4 +++ .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../spark/deploy/yarn/ClientArguments.scala | 4 +++ .../spark/deploy/yarn/YarnClusterSuite.scala | 4 +++ 
182 files changed, 478 insertions(+), 135 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 7fcb7830e7b0b..87ab099267b2f 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,6 +121,7 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { + // scalastyle:off println if (Utils.isInInterpreter) { val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { @@ -141,6 +142,7 @@ trait Logging { System.err.println(s"Spark was unable to load $defaultLogProps") } } + // scalastyle:on println } } Logging.initialized = true diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 1a5f2bca26c2b..b7e72d4d0ed0b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -95,7 +95,9 @@ private[spark] class RBackend { private[spark] object RBackend extends Logging { def main(args: Array[String]): Unit = { if (args.length < 1) { + // scalastyle:off println System.err.println("Usage: RBackend ") + // scalastyle:on println System.exit(-1) } val sparkRBackend = new RBackend() diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 524676544d6f5..ff1702f7dea48 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -161,7 +161,9 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( dataOut.write(elem.asInstanceOf[Array[Byte]]) } else if (deserializer == SerializationFormats.STRING) { // write string(for StringRRDD) + // scalastyle:off println printOut.println(elem) + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 71f7e2129116f..f03875a3e8c89 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -118,26 +118,26 @@ private class ClientEndpoint( def pollAndReportStatus(driverId: String) { // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread // is fine. - println("... waiting before polling master for driver state") + logInfo("... waiting before polling master for driver state") Thread.sleep(5000) - println("... polling master for driver state") + logInfo("... 
polling master for driver state") val statusResponse = activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => - println(s"ERROR: Cluster master did not recognize $driverId") + logError(s"ERROR: Cluster master did not recognize $driverId") System.exit(-1) case true => - println(s"State of $driverId is ${statusResponse.state.get}") + logInfo(s"State of $driverId is ${statusResponse.state.get}") // Worker node, if present (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => - println(s"Driver running on $hostPort ($id)") + logInfo(s"Driver running on $hostPort ($id)") case _ => } // Exception, if present statusResponse.exception.map { e => - println(s"Exception from cluster was: $e") + logError(s"Exception from cluster was: $e") e.printStackTrace() System.exit(-1) } @@ -148,7 +148,7 @@ private class ClientEndpoint( override def receive: PartialFunction[Any, Unit] = { case SubmitDriverResponse(master, success, driverId, message) => - println(message) + logInfo(message) if (success) { activeMasterEndpoint = master pollAndReportStatus(driverId.get) @@ -158,7 +158,7 @@ private class ClientEndpoint( case KillDriverResponse(master, driverId, success, message) => - println(message) + logInfo(message) if (success) { activeMasterEndpoint = master pollAndReportStatus(driverId) @@ -169,13 +169,13 @@ private class ClientEndpoint( override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master $remoteAddress.") + logError(s"Error connecting to master $remoteAddress.") lostMasters += remoteAddress // Note that this heuristic does not account for the fact that a Master can recover within // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This // is not currently a concern, however, because this client does not retry submissions. 
if (lostMasters.size >= masterEndpoints.size) { - println("No master is available, exiting.") + logError("No master is available, exiting.") System.exit(-1) } } @@ -183,18 +183,18 @@ private class ClientEndpoint( override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master ($remoteAddress).") - println(s"Cause was: $cause") + logError(s"Error connecting to master ($remoteAddress).") + logError(s"Cause was: $cause") lostMasters += remoteAddress if (lostMasters.size >= masterEndpoints.size) { - println("No master is available, exiting.") + logError("No master is available, exiting.") System.exit(-1) } } } override def onError(cause: Throwable): Unit = { - println(s"Error processing messages, exiting.") + logError(s"Error processing messages, exiting.") cause.printStackTrace() System.exit(-1) } @@ -209,10 +209,12 @@ private class ClientEndpoint( */ object Client { def main(args: Array[String]) { + // scalastyle:off println if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a future version of Spark") println("Use ./bin/spark-submit with \"--master spark://host:port\"") } + // scalastyle:on println val conf = new SparkConf() val driverArgs = new ClientArguments(args) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 42d3296062e6d..72cc330a398da 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -72,9 +72,11 @@ private[deploy] class ClientArguments(args: Array[String]) { cmd = "launch" if (!ClientArguments.isValidJarUrl(_jarUrl)) { + // scalastyle:off println println(s"Jar url '${_jarUrl}' is not in valid format.") println(s"Must be a jar file path in URL format " + "(e.g. 
hdfs://host:port/XX.jar, file:///XX.jar)") + // scalastyle:on println printUsageAndExit(-1) } @@ -110,7 +112,9 @@ private[deploy] class ClientArguments(args: Array[String]) { | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin + // scalastyle:off println System.err.println(usage) + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index e99779f299785..4165740312e03 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -85,7 +85,9 @@ object RRunner { } System.exit(returnCode) } else { + // scalastyle:off println System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b1d6ec209d62b..4cec9017b8adb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -82,6 +82,7 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + // scalastyle:off println // Exposed for testing private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err @@ -102,11 +103,14 @@ object SparkSubmit { printStream.println("Type --help for more information.") exitFn(0) } + // scalastyle:on println def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { + // scalastyle:off println printStream.println(appArgs) + // scalastyle:on println } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs) @@ -160,7 +164,9 @@ object SparkSubmit { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { + // scalastyle:off println printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + // scalastyle:on println exitFn(1) } else { throw e @@ -178,7 +184,9 @@ object SparkSubmit { // to use the legacy gateway if the master endpoint turns out to be not a REST server. 
if (args.isStandaloneCluster && args.useRest) { try { + // scalastyle:off println printStream.println("Running Spark using the REST application submission protocol.") + // scalastyle:on println doRunMain() } catch { // Fail over to use the legacy submission gateway @@ -558,6 +566,7 @@ object SparkSubmit { sysProps: Map[String, String], childMainClass: String, verbose: Boolean): Unit = { + // scalastyle:off println if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -565,6 +574,7 @@ object SparkSubmit { printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } + // scalastyle:on println val loader = if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { @@ -592,8 +602,10 @@ object SparkSubmit { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { + // scalastyle:off println printStream.println(s"Failed to load main class $childMainClass.") printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -766,7 +778,9 @@ private[spark] object SparkSubmitUtils { brr.setRoot(repo) brr.setName(s"repo-${i + 1}") cr.add(brr) + // scalastyle:off println printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + // scalastyle:on println } } @@ -829,7 +843,9 @@ private[spark] object SparkSubmitUtils { val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) val dd = new DefaultDependencyDescriptor(ri, false, false) dd.addDependencyConfiguration(ivyConfName, ivyConfName) + // scalastyle:off println printStream.println(s"${dd.getDependencyId} added as a dependency") + // scalastyle:on println md.addDependency(dd) } } @@ -896,9 +912,11 @@ private[spark] object SparkSubmitUtils { ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) new File(alternateIvyCache, "jars") } + // scalastyle:off println printStream.println( s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // scalastyle:on println // create a pattern matcher ivySettings.addMatcher(new GlobPatternMatcher) // create the dependency resolvers diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 6e3c0b21b33c2..ebb39c354dff1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -79,6 +79,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. 
*/ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() + // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => Utils.getPropertiesFromFile(filename).foreach { case (k, v) => @@ -86,6 +87,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } } + // scalastyle:on println defaultProperties } @@ -452,6 +454,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { + // scalastyle:off println val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) @@ -541,6 +544,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println("CLI options:") outStream.println(getSqlShellOptions()) } + // scalastyle:on println SparkSubmit.exitFn(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala index c5ac45c6730d3..a98b1fa8f83a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala @@ -19,7 +19,9 @@ package org.apache.spark.deploy.client private[spark] object TestExecutor { def main(args: Array[String]) { + // scalastyle:off println println("Hello world!") + // scalastyle:on println while (true) { Thread.sleep(1000) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 4692d22651c93..18265df9faa2c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -56,6 +56,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin Utils.loadDefaultSparkProperties(conf, propertiesFile) private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( """ |Usage: HistoryServer [options] @@ -84,6 +85,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin | spark.history.fs.updateInterval How often to reload log data from storage | (in seconds, default: 10) |""".stripMargin) + // scalastyle:on println System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 435b9b12f83b8..44cefbc77f08e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -85,6 +85,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. 
*/ private def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Master [options]\n" + "\n" + @@ -95,6 +96,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8080)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 894cb78d8591a..5accaf78d0a51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -54,7 +54,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--master" | "-m") :: value :: tail => if (!value.startsWith("mesos://")) { + // scalastyle:off println System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + // scalastyle:on println System.exit(1) } masterUrl = value.stripPrefix("mesos://") @@ -73,7 +75,9 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case Nil => { if (masterUrl == null) { + // scalastyle:off println System.err.println("--master is required") + // scalastyle:on println printUsageAndExit(1) } } @@ -83,6 +87,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: } private def printUsageAndExit(exitCode: Int): Unit = { + // scalastyle:off println System.err.println( "Usage: MesosClusterDispatcher [options]\n" + "\n" + @@ -96,6 +101,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: " Zookeeper for persistence\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index d1a12b01e78f7..2d6be3042c905 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -60,7 +60,9 @@ object DriverWrapper { rpcEnv.shutdown() case _ => + // scalastyle:off println System.err.println("Usage: DriverWrapper [options]") + // scalastyle:on println System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 1d2ecab517613..e89d076802215 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -121,6 +121,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { * Print usage and exit JVM with the given exit code. 
*/ def printUsageAndExit(exitCode: Int) { + // scalastyle:off println System.err.println( "Usage: Worker [options] \n" + "\n" + @@ -136,6 +137,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { " --webui-port PORT Port for web UI (default: 8081)\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + " Default is conf/spark-defaults.conf.") + // scalastyle:on println System.exit(exitCode) } @@ -160,7 +162,9 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { } catch { case e: Exception => { totalMb = 2*1024 + // scalastyle:off println System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + // scalastyle:on println } } // Leave out 1 GB for the operating system, but don't return a negative memory size diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 34d4cfdca7732..fcd76ec52742a 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -235,7 +235,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { argv = tail case Nil => case tail => + // scalastyle:off println System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + // scalastyle:on println printUsageAndExit() } } @@ -249,6 +251,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } private def printUsageAndExit() = { + // scalastyle:off println System.err.println( """ |"Usage: CoarseGrainedExecutorBackend [options] @@ -262,6 +265,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { | --worker-url | --user-class-path |""".stripMargin) + // scalastyle:on println System.exit(1) } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index c219d21fbefa9..532850dd57716 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -21,6 +21,8 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil /** @@ -39,7 +41,8 @@ private[spark] object FixedLengthBinaryInputFormat { } private[spark] class FixedLengthBinaryInputFormat - extends FileInputFormat[LongWritable, BytesWritable] { + extends FileInputFormat[LongWritable, BytesWritable] + with Logging { private var recordLength = -1 @@ -51,7 +54,7 @@ private[spark] class FixedLengthBinaryInputFormat recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) } if (recordLength <= 0) { - println("record length is less than 0, file cannot be split") + logDebug("record length is less than 0, file cannot be split") false } else { true diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index 67a376102994c..79cb0640c8672 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ 
b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -57,16 +57,6 @@ private[nio] class BlockMessage() { } def set(buffer: ByteBuffer) { - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ typ = buffer.getInt() val idLength = buffer.getInt() val idBuilder = new StringBuilder(idLength) @@ -138,18 +128,6 @@ private[nio] class BlockMessage() { buffers += data } - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 7d0806f0c2580..f1c9ea8b64ca3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -43,16 +43,6 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) val newBlockMessages = new ArrayBuffer[BlockMessage]() val buffer = bufferMessage.buffers(0) buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ while (buffer.remaining() > 0) { val size = buffer.getInt() logDebug("Creating block message of size " + size + " bytes") @@ -86,23 +76,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) logDebug("Buffer list:") buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ Message.createBufferMessage(buffers) } } -private[nio] object BlockMessageArray { +private[nio] object BlockMessageArray extends Logging { def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() @@ -123,10 +101,10 @@ private[nio] object BlockMessageArray { } } val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") + logDebug("Block message array created") val bufferMessage = blockMessageArray.toBufferMessage - println("Converted to buffer message") + logDebug("Converted to buffer message") val totalSize = bufferMessage.size val newBuffer = ByteBuffer.allocate(totalSize) @@ -138,10 +116,11 @@ private[nio] object BlockMessageArray { }) newBuffer.flip val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) + logDebug("Copied to new buffer message, size = " + newBufferMessage.size) val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") + logDebug("Converted back to block message array") + // scalastyle:off println newBlockMessageArray.foreach(blockMessage => { blockMessage.getType match { case BlockMessage.TYPE_PUT_BLOCK => { @@ -154,6 +133,7 @@ private[nio] object BlockMessageArray { } } }) + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index c0bca2c4bc994..9143918790381 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ 
b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -1016,7 +1016,9 @@ private[spark] object ConnectionManager { val conf = new SparkConf val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // scalastyle:off println println("Received [" + msg + "] from [" + id + "]") + // scalastyle:on println None }) @@ -1033,6 +1035,7 @@ private[spark] object ConnectionManager { System.gc() } + // scalastyle:off println def testSequentialSending(manager: ConnectionManager) { println("--------------------------") println("Sequential Sending") @@ -1150,4 +1153,5 @@ private[spark] object ConnectionManager { println() } } + // scalastyle:on println } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index dc60d48927624..defdabf95ac4b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -123,7 +123,9 @@ private[spark] class PipedRDD[T: ClassTag]( new Thread("stderr reader for " + command) { override def run() { for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + // scalastyle:off println System.err.println(line) + // scalastyle:on println } } }.start() @@ -133,6 +135,7 @@ private[spark] class PipedRDD[T: ClassTag]( override def run() { val out = new PrintWriter(proc.getOutputStream) + // scalastyle:off println // input the pipe context firstly if (printPipeContext != null) { printPipeContext(out.println(_)) @@ -144,6 +147,7 @@ private[spark] class PipedRDD[T: ClassTag]( out.println(elem) } } + // scalastyle:on println out.close() } }.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 529a5b2bf1a0d..62b05033a9281 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -140,7 +140,9 @@ private[spark] class EventLoggingListener( /** Log the event as JSON. 
*/ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { val eventJson = JsonProtocol.sparkEventToJson(event) + // scalastyle:off println writer.foreach(_.println(compact(render(eventJson)))) + // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index e55b76c36cc5f..f96eb8ca0ae00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -125,7 +125,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener val date = new Date(System.currentTimeMillis()) writeInfo = dateFormat.get.format(date) + ": " + info } + // scalastyle:off println jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) + // scalastyle:on println } /** diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index f413c1d37fbb6..c8356467fab87 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -68,7 +68,9 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) + // scalastyle:on println } else { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index ba03acdb38cc5..5a8c2914314c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -38,9 +38,11 @@ private[spark] object UIWorkloadGenerator { def main(args: Array[String]) { if (args.length < 3) { + // scalastyle:off println println( - "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + + "Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " + "[master] [FIFO|FAIR] [#job set (4 jobs per set)]") + // scalastyle:on println System.exit(1) } @@ -96,6 +98,7 @@ private[spark] object UIWorkloadGenerator { for ((desc, job) <- jobs) { new Thread { override def run() { + // scalastyle:off println try { setProperties(desc) job() @@ -106,6 +109,7 @@ private[spark] object UIWorkloadGenerator { } finally { barrier.release() } + // scalastyle:on println } }.start Thread.sleep(INTER_JOB_WAIT_MS) diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index 1bab707235b89..950b69f7db641 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -52,9 +52,11 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va } def showQuantiles(out: PrintStream = System.out): Unit = { + // scalastyle:off println out.println("min\t25%\t50%\t75%\tmax") getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } def statCounter: StatCounter = 
StatCounter(data.slice(startIdx, endIdx)) @@ -64,8 +66,10 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va * @param out */ def summary(out: PrintStream = System.out) { + // scalastyle:off println out.println(statCounter) showQuantiles(out) + // scalastyle:on println } } @@ -80,8 +84,10 @@ private[spark] object Distribution { } def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { + // scalastyle:off println out.println("min\t25%\t50%\t75%\tmax") quantiles.foreach{q => out.print(q + "\t")} out.println + // scalastyle:on println } } diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index c4a7b4441c85c..85fb923cd9bc7 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -70,12 +70,14 @@ private[spark] object XORShiftRandom { * @param args takes one argument - the number of random numbers to generate */ def main(args: Array[String]): Unit = { + // scalastyle:off println if (args.length != 1) { println("Benchmark of XORShiftRandom vis-a-vis java.util.Random") println("Usage: XORShiftRandom number_of_random_numbers_to_generate") System.exit(1) } println(benchmark(args(0).toInt)) + // scalastyle:on println } /** diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 9c191ed52206d..2300bcff4f118 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -107,7 +107,9 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc = new SparkContext(clusterUrl, "test") val accum = sc.accumulator(0) val thrown = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) + // scalastyle:on println } assert(thrown.getClass === classOf[SparkException]) assert(thrown.getMessage.contains("failed 4 times")) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index a8c8c6f73fb5a..b099cd3fb7965 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -130,7 +130,9 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // Non-serializable closure in foreach function val thrown2 = intercept[SparkException] { + // scalastyle:off println sc.parallelize(1 to 10, 2).foreach(x => println(a)) + // scalastyle:on println } assert(thrown2.getClass === classOf[SparkException]) assert(thrown2.getMessage.contains("NotSerializableException") || diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 6e65b0a8f6c76..876418aa13029 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -51,7 +51,9 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { val textFile = new File(testTempDir, "FileServerSuite.txt") val pw = new PrintWriter(textFile) + // scalastyle:off println pw.println("100") + // scalastyle:on println pw.close() val jarFile = new File(testTempDir, "test.jar") diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala 
b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 6580139df6c60..48509f0759a3b 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -36,7 +36,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends SparkFunSuite with LocalSparkContext { +class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") @@ -130,8 +130,6 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { Thread.sleep(100) } if (running.get() != 4) { - println("Waited 1 second without seeing runningThreads = 4 (it was " + - running.get() + "); failing test") ThreadingSuiteState.failed.set(true) } number @@ -143,6 +141,8 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext { } sem.acquire(2) if (ThreadingSuiteState.failed.get()) { + logError("Waited 1 second without seeing runningThreads = 4 (it was " + + ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2e05dec99b6bf..1b64c329b5d4b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -51,9 +51,11 @@ class SparkSubmitSuite /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } /** Returns true if the script exits and the given search string is printed. */ @@ -81,6 +83,7 @@ class SparkSubmitSuite } } + // scalastyle:off println test("prints usage on empty input") { testPrematureExit(Array[String](), "Usage: spark-submit") } @@ -491,6 +494,7 @@ class SparkSubmitSuite appArgs.executorMemory should be ("2.3g") } } + // scalastyle:on println // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
private def runSparkSubmit(args: Seq[String]): Unit = { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index c9b435a9228d3..01ece1a10f46d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -41,9 +41,11 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { /** Simple PrintStream that reads data into a buffer */ private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() + // scalastyle:off println override def println(line: String) { lineBuffer += line } + // scalastyle:on println } override def beforeAll() { diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 63947df3d43a2..8a199459c1ddf 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.io.Text -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -36,7 +36,7 @@ import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, Gzi * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary * directory is created as fake input. Temporal storage would be deleted in the end. 
*/ -class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll { +class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { private var sc: SparkContext = _ private var factory: CompressionCodecFactory = _ @@ -85,7 +85,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl */ test("Correctness of WholeTextFileRecordReader.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, false) @@ -109,7 +109,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl test("Correctness of WholeTextFileRecordReader with GzipCodec.") { val dir = Utils.createTempDir() - println(s"Local disk address is ${dir.toString}.") + logInfo(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => createNativeFile(dir, filename, contents, true) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 9e4d34fb7d382..d3218a548efc7 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -60,7 +60,9 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) for (x <- 1 to numRecords) { + // scalastyle:off println pw.println(RandomUtils.nextInt(0, numBuckets)) + // scalastyle:on println } pw.close() diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index ff3fa95ec32ae..4e3defb43a021 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -52,8 +52,10 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) + // scalastyle:off println writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart)))) writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd)))) + // scalastyle:on println writer.close() val conf = EventLoggingListenerSuite.getLoggingConf(logFilePath) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 1053c6caf7718..480722a5ac182 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -375,6 +375,7 @@ class TestCreateNullValue { // parameters of the closure constructor. This allows us to test whether // null values are created correctly for each type. 
val nestedClosure = () => { + // scalastyle:off println if (s.toString == "123") { // Don't really output them to avoid noisy println(bo) println(c) @@ -389,6 +390,7 @@ class TestCreateNullValue { val closure = () => { println(getX) } + // scalastyle:on println ClosureCleaner.clean(closure) } nestedClosure() diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 251a797dc28a2..c7638507c88c6 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -684,7 +684,9 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val buffer = new CircularBuffer(25) val stream = new java.io.PrintStream(buffer, true, "UTF-8") + // scalastyle:off println stream.println("test circular test circular test circular test circular test circular") + // scalastyle:on println assert(buffer.toString === "t circular test circular\n") } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala index 5a5919fca2469..4f382414a8dd7 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala @@ -103,7 +103,9 @@ private object SizeTrackerSuite { */ def main(args: Array[String]): Unit = { if (args.size < 1) { + // scalastyle:off println println("Usage: SizeTrackerSuite [num elements]") + // scalastyle:on println System.exit(1) } val numElements = args(0).toInt @@ -180,11 +182,13 @@ private object SizeTrackerSuite { baseTimes: Seq[Long], sampledTimes: Seq[Long], unsampledTimes: Seq[Long]): Unit = { + // scalastyle:off println println(s"Average times for $testName (ms):") println(" Base - " + averageTime(baseTimes)) println(" SizeTracker (sampled) - " + averageTime(sampledTimes)) println(" SizeEstimator (unsampled) - " + averageTime(unsampledTimes)) println() + // scalastyle:on println } def time(f: => Unit): Long = { diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index b2f5d9009ee5d..fefa5165db197 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.util.collection import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.random.XORShiftRandom -class SorterSuite extends SparkFunSuite { +class SorterSuite extends SparkFunSuite with Logging { test("equivalent to Arrays.sort") { val rand = new XORShiftRandom(123) @@ -74,7 +74,7 @@ class SorterSuite extends SparkFunSuite { /** Runs an experiment several times. 
*/ def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = { if (skip) { - println(s"Skipped experiment $name.") + logInfo(s"Skipped experiment $name.") return } @@ -86,11 +86,11 @@ class SorterSuite extends SparkFunSuite { while (i < 10) { val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare)) next10 += time - println(s"$name: Took $time ms") + logInfo(s"$name: Took $time ms") i += 1 } - println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") + logInfo(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") } /** diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala index fc03fec9866a6..61d91c70e9709 100644 --- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -59,3 +60,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala index 0be8e64fbfabd..9f7ae75d0b477 100644 --- a/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.util.Try @@ -37,3 +38,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala index 24c7f8d667296..2f0b6ef9a5672 100644 --- a/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala +++ b/dev/audit-release/sbt_app_graphx/src/main/scala/GraphxApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import org.apache.spark.{SparkContext, SparkConf} @@ -51,3 +52,4 @@ object GraphXApp { println("Test succeeded") } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index 5111bc0adb772..4a980ec071ae4 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -55,3 +56,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala index 9f85066501472..adc25b57d6aa5 100644 --- a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package main.scala import scala.util.Try @@ -31,3 +32,4 @@ object SimpleApp { } } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index cc86ef45858c9..69c1154dc0955 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -57,3 +58,4 @@ object SparkSqlExample { sc.stop() } } +// scalastyle:on println diff --git a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala index 58a662bd9b2e8..d6a074687f4a1 100644 --- a/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala +++ b/dev/audit-release/sbt_app_streaming/src/main/scala/StreamingApp.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package main.scala import scala.collection.mutable.{ListBuffer, Queue} @@ -61,3 +62,4 @@ object SparkStreamingExample { ssc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 4c129dbe2d12d..d812262fd87dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -52,3 +53,4 @@ object BroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 023bb3ee2d108..36832f51d2ad4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ + // scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -140,3 +141,4 @@ object CassandraCQLTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ec689474aecb0..96ef3e198e380 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.nio.ByteBuffer @@ -130,6 +131,7 @@ object CassandraTest { sc.stop() } } +// scalastyle:on println /* create keyspace casDemo; diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 1f12034ce0f57..d651fe4d6ee75 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import java.io.File @@ -136,3 +137,4 @@ object DFSReadWriteTest { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index e757283823fc3..c42df2b8845d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.collection.JavaConversions._ @@ -46,3 +47,4 @@ object DriverSubmissionTest { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 15f6678648b29..fa4a3afeecd19 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -53,3 +54,4 @@ object GroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 95c96111c9b1f..244742327a907 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.hadoop.hbase.client.HBaseAdmin @@ -62,3 +63,4 @@ object HBaseTest { admin.close() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index ed2b38e2ca6f8..124dc9af6390f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark._ @@ -41,3 +42,4 @@ object HdfsTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 3d5259463003d..af5f216f28ba4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -142,3 +143,4 @@ object LocalALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index ac2ea35bbd0e0..9c8aae53cf48d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -73,3 +74,4 @@ object LocalFileLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 04fc0a033014a..e7b28d38bdfc6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -119,3 +120,4 @@ object LocalKMeans { println("Final centers: " + kPoints) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index c3fc74a116c0a..4f6b092a59ca5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -77,3 +78,4 @@ object LocalLR { println("Final w: " + w) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index ee6b3ee34aeb2..3d923625f11b6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -33,3 +34,4 @@ object LocalPi { println("Pi is roughly " + 4 * count / 100000.0) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 75c82117cbad2..a80de10f4610a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} @@ -83,3 +84,4 @@ object LogQuery { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 2a5c0c0defe13..61ce9db914f9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.rdd.RDD @@ -53,3 +54,4 @@ object MultiBroadcastTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 5291ab81f459e..3b0b00fe4dd0a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -67,3 +68,4 @@ object SimpleSkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 017d4e1e5ce13..719e2176fed3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -57,3 +58,4 @@ object SkewedGroupByTest { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 30c4261551837..69799b7c2bb30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import org.apache.commons.math3.linear._ @@ -144,3 +145,4 @@ object SparkALS { new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 9099c2fcc90b3..505ea5a4c7a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -97,3 +98,4 @@ object SparkHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index b514d9123f5e7..c56e1124ad415 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import breeze.linalg.{Vector, DenseVector, squaredDistance} @@ -100,3 +101,4 @@ object SparkKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 1e6b4fb0c7514..d265c227f4ed2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -86,3 +87,4 @@ object SparkLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index bd7894f184c4c..0fd79660dd196 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples import org.apache.spark.SparkContext._ @@ -74,3 +75,4 @@ object SparkPageRank { ctx.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 35b8dd6c29b66..818d4f2b81f82 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -37,3 +38,4 @@ object SparkPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 772cd897f5140..95072071ccddb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.util.Random @@ -70,3 +71,4 @@ object SparkTC { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 4393b99e636b6..cfbdae02212a5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import java.util.Random @@ -94,3 +95,4 @@ object SparkTachyonHdfsLR { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index 7743f7968b100..e46ac655beb58 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples import scala.math.random @@ -46,3 +47,4 @@ object SparkTachyonPi { spark.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 409721b01c8fd..8dd6c9706e7df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import scala.collection.mutable @@ -151,3 +152,4 @@ object Analytics extends Logging { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index f6f8d9f90c275..da3ffca1a6f2a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.graphx /** @@ -42,3 +43,4 @@ object LiveJournalPageRank { Analytics.main(args.patch(0, List("pagerank"), 0)) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 3ec20d594b784..46e52aacd90bb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ @@ -128,3 +129,4 @@ object SynthBenchmark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 6c0af20461d3b..14b358d46f6ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -110,3 +111,4 @@ object CrossValidatorExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 54e4073941056..f28671f7869fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -355,3 +356,4 @@ object DecisionTreeExample { println(s" Root mean squared error (RMSE): $RMSE") } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 7b8cc21ed8982..78f31b4ffe56a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -181,3 +182,4 @@ private class MyLogisticRegressionModel( copyValues(new MyLogisticRegressionModel(uid, weights), extra) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index 33905277c7341..f4a15f806ea81 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -236,3 +237,4 @@ object GBTExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index b54466fd48bc5..b73299fb12d3f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -140,3 +141,4 @@ object LinearRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 3cf193f353fbc..7682557127b51 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -157,3 +158,4 @@ object LogisticRegressionExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 25f21113bf622..cd411397a4b9d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scopt.OptionParser @@ -178,3 +179,4 @@ object MovieLensALS { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index 6927eb8f275cf..bab31f585b0ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} @@ -183,3 +184,4 @@ object OneVsRestExample { (NANO.toSeconds(t1 - t0), result) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 9f7cad68a4594..109178f4137b2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.ml import scala.collection.mutable @@ -244,3 +245,4 @@ object RandomForestExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index a0561e2573fc9..58d7b67674ff7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} @@ -100,3 +101,4 @@ object SimpleParamsExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 1324b066c30c3..960280137cbf9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.ml import scala.beans.BeanInfo @@ -89,3 +90,4 @@ object SimpleTextClassificationPipeline { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a113653810b93..1a4016f76c2ad 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -153,3 +154,4 @@ object BinaryClassification { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index e49129c4e7844..026d4ecc6d10a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -91,3 +92,4 @@ object Correlations { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index cb1abbd18fd4d..69988cc1b9334 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -106,3 +107,4 @@ object CosineSimilarity { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index 520893b26d595..dc13f82488af7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.io.File @@ -119,3 +120,4 @@ object DatasetExample { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 3381941673db8..57ffe3dd2524f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.language.reflectiveCalls @@ -368,3 +369,4 @@ object DecisionTreeRunner { } // scalastyle:on structural.type } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index f8c71ccabc43b..1fce4ba7efd60 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -65,3 +66,4 @@ object DenseGaussianMixture { println() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 14cc5cbb679c5..380d85d60e7b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -107,3 +108,4 @@ object DenseKMeans { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 13f24a1e59610..14b930550d554 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -80,3 +81,4 @@ object FPGrowthExample { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 7416fb5a40848..e16a6bf033574 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -145,3 +146,4 @@ object GradientBoostedTreesRunner { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 31d629f853161..75b0f69cf91aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import java.text.BreakIterator @@ -302,3 +303,4 @@ private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Se } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 6a456ba7ec07b..8878061a0970b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -134,3 +135,4 @@ object LinearRegression { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 99588b0984ab2..e43a6f2864c73 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import scala.collection.mutable @@ -189,3 +190,4 @@ object MovieLensALS { math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean()) } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 6e4e2d07f284b..5f839c75dd581 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import scopt.OptionParser @@ -97,3 +98,4 @@ object MultivariateSummarizer { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 6d8b806569dfd..0723223954610 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -154,4 +155,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } - +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index 924b586e3af99..bee85ba0f9969 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.random.RandomRDDs @@ -58,3 +59,4 @@ object RandomRDDGeneration { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 663c12734af68..6963f43e082c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.util.MLUtils @@ -125,3 +126,4 @@ object SampledRDDs { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index f1ff4e6911f5e..f81fc292a3bd1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} @@ -100,3 +101,4 @@ object SparseNaiveBayes { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala index 8bb12d2ee9ed2..af03724a8ac62 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.SparkConf @@ -75,3 +76,4 @@ object StreamingKMeansExample { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1a95048bbfe2d..b4a5dca031abd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -69,3 +70,4 @@ object StreamingLinearRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index e1998099c2d78..b42f4cb5f9338 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors @@ -71,3 +72,4 @@ object StreamingLogisticRegression { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 3cd9cb743e309..464fbd385ab5d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnyPCA { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 4d6690318615a..65b4bc46f0266 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} @@ -58,3 +59,4 @@ object TallSkinnySVD { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index b11e32047dc34..2cc56f04e5c1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} @@ -73,3 +74,4 @@ object RDDRelation { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index b7ba60ec28155..bf40bd1ef13df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.sql.hive import com.google.common.io.{ByteStreams, Files} @@ -77,3 +78,4 @@ object HiveFromSpark { sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 016de4c63d1d2..e9c9907198769 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import scala.collection.mutable.LinkedList @@ -170,3 +171,4 @@ object ActorWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 30269a7ccae97..28e9bf520e568 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.{InputStreamReader, BufferedReader, InputStream} @@ -100,3 +101,4 @@ class CustomReceiver(host: String, port: Int) } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index fbe394de4a179..bd78526f8c299 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import kafka.serializer.StringDecoder @@ -70,3 +71,4 @@ object DirectKafkaWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala index 20e7df7c45b1b..91e52e4eff5a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -66,3 +67,4 @@ object FlumeEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 1cc8c8d5c23b6..2bdbc37e2a289 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -65,3 +66,4 @@ object FlumePollingEventCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index 4b4667fec44e6..1f282d437dc38 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -53,3 +54,4 @@ object HdfsWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 60416ee343544..b40d17e9c2fa3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.util.HashMap @@ -101,3 +102,4 @@ object KafkaWordCountProducer { } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 813c8554f5193..d772ae309f40d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.eclipse.paho.client.mqttv3._ @@ -96,8 +97,10 @@ object MQTTWordCount { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println System.err.println( "Usage: MQTTWordCount ") + // scalastyle:on println System.exit(1) } @@ -113,3 +116,4 @@ object MQTTWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 2cd8073dada14..9a57fe286d1ae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. 
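MQTTWordCount above also shows the narrow form of the same suppression: when only a usage message needs to reach the console, the off/on comments bracket just those lines and the rest of the file stays subject to the check. A hedged sketch of that shape, with an illustrative object and argument names:

    object UsageMessageSketch {
      def main(args: Array[String]): Unit = {
        if (args.length < 2) {
          // Only the usage message is exempted; the remainder of the file is still linted.
          // scalastyle:off println
          System.err.println("Usage: UsageMessageSketch <brokerUrl> <topic>")
          // scalastyle:on println
          System.exit(1)
        }
        // The actual job setup would follow here.
      }
    }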
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -57,3 +58,4 @@ object NetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index a9aaa445bccb6..5322929d177b4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -58,3 +59,4 @@ object RawNetworkGrep { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 751b30ea15782..9916882e4f94a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.io.File @@ -108,3 +109,4 @@ object RecoverableNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 5a6b9216a3fbc..ed617754cbf1c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -99,3 +100,4 @@ object SQLContextSingleton { instance } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 345d0bc441351..02ba1c2eed0f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.SparkConf @@ -78,3 +79,4 @@ object StatefulNetworkWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index c10de84a80ffe..825c671a929b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird._ @@ -113,3 +114,4 @@ object TwitterAlgebirdCMS { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 62db5e663b8af..49826ede70418 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import com.twitter.algebird.HyperLogLogMonoid @@ -90,3 +91,4 @@ object TwitterAlgebirdHLL { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index f253d75b279f7..49cee1b43c2dc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -82,3 +83,4 @@ object TwitterPopularTags { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index e99d1baa72b9f..6ac9a72c37941 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import akka.actor.ActorSystem @@ -97,3 +98,4 @@ object ZeroMQWordCount { ssc.awaitTermination() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 889f052c70263..bea7a47cb2855 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import java.net.ServerSocket @@ -108,3 +109,4 @@ object PageViewGenerator { } } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index fbacaee98690f..ec7d39da8b2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off println package org.apache.spark.examples.streaming.clickstream import org.apache.spark.SparkContext._ @@ -107,3 +108,4 @@ object PageViewStream { ssc.start() } } +// scalastyle:on println diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 8e1715f6dbb95..5b3c79444aa68 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -111,7 +111,7 @@ class DirectKafkaStreamSuite rdd }.foreachRDD { rdd => for (o <- offsetRanges) { - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") } val collected = rdd.mapPartitionsWithIndex { (i, iter) => // For each partition, get size of the range in the partition, diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index be8b62d3cc6ba..de749626ec09c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off println package org.apache.spark.examples.streaming import java.nio.ByteBuffer @@ -272,3 +273,4 @@ private[streaming] object StreamingExamples extends Logging { } } } +// scalastyle:on println diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index be6b9047d932d..5c07b415cd796 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -66,7 +66,6 @@ private[graphx] object BytecodeUtils { val finder = new MethodInvocationFinder(c.getName, m) getClassReader(c).accept(finder, 0) for (classMethod <- finder.methodsInvoked) { - // println(classMethod) if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { return true } else if (!seen.contains(classMethod)) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 9591c4e9b8f4e..989e226305265 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx.Edge import org.apache.spark.graphx.impl.GraphImpl /** A collection of graph generating functions. 
*/ -object GraphGenerators { +object GraphGenerators extends Logging { val RMATa = 0.45 val RMATb = 0.15 @@ -142,7 +142,7 @@ object GraphGenerators { var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { - println(edges.size + " edges") + logDebug(edges.size + " edges") } edges += addEdge(numVertices) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index 186d0cc2a977b..61e44dcab578c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.graphx.util import org.apache.spark.SparkFunSuite +// scalastyle:off println class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass @@ -102,6 +103,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { private val c = {e: TestClass => println(e.baz)} } +// scalastyle:on println object BytecodeUtilsSuite { class TestClass(val foo: Int, val bar: Long) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala index 6eaebaf7dba9f..e6bcff48b022c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -64,8 +64,10 @@ object KMeansDataGenerator { def main(args: Array[String]) { if (args.length < 6) { + // scalastyle:off println println("Usage: KMeansGenerator " + " []") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index b4e33c98ba7e5..87eeb5db05d26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -153,8 +153,10 @@ object LinearDataGenerator { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println println("Usage: LinearDataGenerator " + " [num_examples] [num_features] [num_partitions]") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 9d802678c4a77..c09cbe69bb971 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -64,8 +64,10 @@ object LogisticRegressionDataGenerator { def main(args: Array[String]) { if (args.length != 5) { + // scalastyle:off println println("Usage: LogisticRegressionGenerator " + " ") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index bd73a866c8a82..16f430599a515 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -55,8 +55,10 @@ import org.apache.spark.rdd.RDD object MFDataGenerator { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println println("Usage: 
MFDataGenerator " + " [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index a8e30cc9d730c..ad20b7694a779 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -37,8 +37,10 @@ object SVMDataGenerator { def main(args: Array[String]) { if (args.length < 2) { + // scalastyle:off println println("Usage: SVMGenerator " + " [num_examples] [num_features] [num_partitions]") + // scalastyle:on println System.exit(1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 8c85c96d5c6d8..03120c828ca96 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} @@ -27,7 +27,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { import VectorIndexerSuite.FeatureData @@ -113,11 +113,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { model.transform(sparsePoints1) // should work intercept[SparkException] { model.transform(densePoints2).collect() - println("Did not throw error when fit, transform were called on vectors of different lengths") + logInfo("Did not throw error when fit, transform were called on vectors of different lengths") } intercept[SparkException] { vectorIndexer.fit(badPoints) - println("Did not throw error when fitting vectors of different lengths in same RDD.") + logInfo("Did not throw error when fitting vectors of different lengths in same RDD.") } } @@ -196,7 +196,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { } } catch { case e: org.scalatest.exceptions.TestFailedException => - println(errMsg) + logError(errMsg) throw e } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index c4ae0a16f7c04..178d95a7b94ec 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -21,10 +21,10 @@ import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ -class VectorsSuite extends SparkFunSuite { +class VectorsSuite extends SparkFunSuite with Logging { val arr = Array(0.1, 0.0, 0.3, 0.4) val n = 4 @@ -142,7 +142,7 @@ class VectorsSuite extends 
SparkFunSuite { malformatted.foreach { s => intercept[SparkException] { Vectors.parse(s) - println(s"Didn't detect malformatted string $s.") + logInfo(s"Didn't detect malformatted string $s.") } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index c292ced75e870..c3eeda012571c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.stat import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { // test input data val xData = Array(1.0, 0.0, -2.0) @@ -146,7 +146,7 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = { for (i <- 0 until A.rows; j <- 0 until A.cols) { if (!approxEqual(A(i, j), B(i, j), threshold)) { - println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j)) + logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j)) return false } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 84dd3b342d4c0..2521b3342181a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GradientBoostedTrees]]. 
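For test and library code the patch takes the other route: the class mixes in org.apache.spark.Logging and the former println calls become logInfo, logWarning, or logError. A hedged sketch of that conversion, assuming the suite lives in Spark's own test tree (SparkFunSuite is package-private) and using an illustrative name and message rather than code from the patch:

    package org.apache.spark.mllib.util

    import org.apache.spark.{Logging, SparkFunSuite}

    // Hypothetical suite mirroring the VectorsSuite/CorrelationSuite changes above.
    class LoggingConversionSketchSuite extends SparkFunSuite with Logging {
      test("diagnostics go through the logger") {
        // Where the old code called println(...), the diagnostic now goes to log4j.
        logInfo("this message lands in the test logs instead of stdout")
      }
    }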
*/ -class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext { +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { @@ -50,7 +50,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) } catch { case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + s" subsamplingRate=$subsamplingRate") throw e } @@ -80,7 +80,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae") } catch { case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + s" subsamplingRate=$subsamplingRate") throw e } @@ -111,7 +111,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9) } catch { case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + s" subsamplingRate=$subsamplingRate") throw e } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index fa4f74d71b7e7..16d7c3ab39b03 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -33,7 +33,7 @@ class NumericParserSuite extends SparkFunSuite { malformatted.foreach { s => intercept[SparkException] { NumericParser.parse(s) - println(s"Didn't detect malformatted string $s.") + throw new RuntimeException(s"Didn't detect malformatted string $s.") } } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 3408c6d51ed4c..4291b0be2a616 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -69,6 +69,7 @@ object SparkBuild extends PomBuild { import scala.collection.mutable var isAlphaYarn = false var profiles: mutable.Seq[String] = mutable.Seq("sbt") + // scalastyle:off println if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.") profiles ++= Seq("spark-ganglia-lgpl") @@ -88,6 +89,7 @@ object SparkBuild extends PomBuild { println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.") profiles ++= Seq("yarn") } + // scalastyle:on println profiles } @@ -96,8 +98,10 @@ object SparkBuild extends PomBuild { case None => backwardCompatibility case Some(v) => if (backwardCompatibility.nonEmpty) + // scalastyle:off println println("Note: We ignore environment variables, when use of profile is detected in " + "conjunction with environment variable.") + // scalastyle:on println v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala 
b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala index 6480e2d24e044..24fbbc12c08da 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala @@ -39,6 +39,8 @@ class SparkCommandLine(args: List[String], override val settings: Settings) } def this(args: List[String]) { + // scalastyle:off println this(args, str => Console.println("Error: " + str)) + // scalastyle:on println } } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 2b235525250c2..8f7f9074d3f03 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -1101,7 +1101,9 @@ object SparkILoop extends Logging { val s = super.readLine() // helping out by printing the line being interpreted. if (s != null) + // scalastyle:off println output.println(s) + // scalastyle:on println s } } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 05faef8786d2c..bd3314d94eed6 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -80,11 +80,13 @@ private[repl] trait SparkILoopInit { if (!initIsComplete) withLock { while (!initIsComplete) initLoopCondition.await() } if (initError != null) { + // scalastyle:off println println(""" |Failed to initialize the REPL due to an unexpected error. |This is a bug, please, report it along with the error diagnostics printed below. |%s.""".stripMargin.format(initError) ) + // scalastyle:on println false } else true } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 35fb625645022..8791618bd355e 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1761,7 +1761,9 @@ object SparkIMain { if (intp.totalSilence) () else super.printMessage(msg) } + // scalastyle:off println else Console.println(msg) + // scalastyle:on println } } } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 7a5e94da5cbf3..3c90287249497 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -943,7 +943,9 @@ object SparkILoop { val s = super.readLine() // helping out by printing the line being interpreted. 
if (s != null) + // scalastyle:off println output.println(s) + // scalastyle:on println s } } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 1cb910f376060..56c009a4e38e7 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -129,7 +129,9 @@ class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings } private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" private val logScope = scala.sys.props contains "scala.repl.scope" + // scalastyle:off println private def scopelog(msg: String) = if (logScope) Console.err.println(msg) + // scalastyle:on println // argument is a thunk to execute after init is done def initialize(postInitSignal: => Unit) { @@ -1297,8 +1299,10 @@ class SparkISettings(intp: SparkIMain) { def deprecation_=(x: Boolean) = { val old = intp.settings.deprecation.value intp.settings.deprecation.value = x + // scalastyle:off println if (!old && x) println("Enabled -deprecation output.") else if (old && !x) println("Disabled -deprecation output.") + // scalastyle:on println } def deprecation: Boolean = intp.settings.deprecation.value diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala index 0711ed4871bb6..272f81eca92c1 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala @@ -42,7 +42,9 @@ class SparkReplReporter(intp: SparkIMain) extends ConsoleReporter(intp.settings, } else super.printMessage(msg) } + // scalastyle:off println else Console.println("[init] " + msg) + // scalastyle:on println } override def displayPrompt() { diff --git a/scalastyle-config.xml b/scalastyle-config.xml index d6f927b6fa803..49611703798e8 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -141,12 +141,8 @@ This file is divided into 3 sections: Tests must extend org.apache.spark.SparkFunSuite instead. - - - - - - + + ^println$ + + + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 7f1b12cdd5800..606fecbe06e47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -67,8 +67,10 @@ package object codegen { outfile.write(generatedBytes) outfile.close() + // scalastyle:off println println( s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!) 
+ // scalastyle:on println } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 2f545bb432165..b89e3382f06a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -154,7 +154,9 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def schemaString: String = schema.treeString /** Prints out the schema in the tree format */ + // scalastyle:off println def printSchema(): Unit = println(schemaString) + // scalastyle:on println /** * A prefix string used when printing the plan. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 07054166a5e88..71293475ca0f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -124,7 +124,9 @@ package object util { val startTime = System.nanoTime() val ret = f val endTime = System.nanoTime() + // scalastyle:off println println(s"${(endTime - startTime).toDouble / 1000000}ms") + // scalastyle:on println ret } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e0b8ff91786a7..b8097403ec3cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -250,7 +250,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru builder.toString() } + // scalastyle:off println def printTreeString(): Unit = println(treeString) + // scalastyle:on println private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { fields.foreach(field => field.buildFormattedString(prefix, builder)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index f201c8ea8a110..10250264625b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -860,11 +860,13 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ def explain(extended: Boolean): Unit = { + // scalastyle:off println if (extended) { println(expr) } else { println(expr.prettyString) } + // scalastyle:on println } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d7966651b1948..830fba35bb7bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -308,7 +308,9 @@ class DataFrame private[sql]( * @group basic * @since 1.3.0 */ + // scalastyle:off println def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println /** * Prints the plans (logical and physical) to the console for debugging purposes. 
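Public methods whose entire purpose is console output, such as printSchema, printTreeString, explain, and show above, keep their println calls; the suppression is scoped to the single definition so the rest of the file remains checked. A minimal sketch of that scoping, with an illustrative class:

    // Hypothetical user-facing class; only the printing method is exempted from the check.
    class SchemaPrinterSketch(schemaString: String) {
      // scalastyle:off println
      def printSchema(): Unit = println(schemaString)
      // scalastyle:on println
    }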
@@ -319,7 +321,9 @@ class DataFrame private[sql]( ExplainCommand( queryExecution.logical, extended = extended).queryExecution.executedPlan.executeCollect().map { + // scalastyle:off println r => println(r.getString(0)) + // scalastyle:on println } } @@ -392,7 +396,9 @@ class DataFrame private[sql]( * @group action * @since 1.5.0 */ + // scalastyle:off println def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + // scalastyle:on println /** * Returns a [[DataFrameNaFunctions]] for working with missing data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 2964edac1aba2..e6081cb05bc2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -24,7 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String import scala.collection.mutable.HashSet -import org.apache.spark.{AccumulatorParam, Accumulator} +import org.apache.spark.{AccumulatorParam, Accumulator, Logging} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -57,7 +57,7 @@ package object debug { * Augments [[DataFrame]]s with debug methods. */ @DeveloperApi - implicit class DebugQuery(query: DataFrame) { + implicit class DebugQuery(query: DataFrame) extends Logging { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() @@ -66,7 +66,7 @@ package object debug { visited += new TreeNodeRef(s) DebugNode(s) } - println(s"Results returned: ${debugPlan.execute().count()}") + logDebug(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { case d: DebugNode => d.dumpStats() case _ => @@ -82,11 +82,11 @@ package object debug { TypeCheck(s) } try { - println(s"Results returned: ${debugPlan.execute().count()}") + logDebug(s"Results returned: ${debugPlan.execute().count()}") } catch { case e: Exception => def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause) - println(s"Deepest Error: ${unwrap(e)}") + logDebug(s"Deepest Error: ${unwrap(e)}") } } } @@ -119,11 +119,11 @@ package object debug { val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { - println(s"== ${child.simpleString} ==") - println(s"Tuples output: ${tupleCount.value}") + logDebug(s"== ${child.simpleString} ==") + logDebug(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case(attr, metric) => val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") - println(s" ${attr.name} ${attr.dataType}: $actualDataTypes") + logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 039cfa40d26b3..f66a17b20915f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -40,7 +40,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils -private[hive] object SparkSQLCLIDriver { 
+private[hive] object SparkSQLCLIDriver extends Logging { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') private var transport: TSocket = _ @@ -164,7 +164,7 @@ private[hive] object SparkSQLCLIDriver { } } catch { case e: FileNotFoundException => - System.err.println(s"Could not open input file for reading. (${e.getMessage})") + logError(s"Could not open input file for reading. (${e.getMessage})") System.exit(3) } @@ -180,14 +180,14 @@ private[hive] object SparkSQLCLIDriver { val historyFile = historyDirectory + File.separator + ".hivehistory" reader.setHistory(new History(new File(historyFile))) } else { - System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + logWarning("WARNING: Directory for Hive history file: " + historyDirectory + " does not exist. History will not be available during this session.") } } catch { case e: Exception => - System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + logWarning("WARNING: Encountered an error while trying to initialize Hive's " + "history file. History will not be available during this session.") - System.err.println(e.getMessage) + logWarning(e.getMessage) } val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") @@ -270,6 +270,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) if (proc != null) { + // scalastyle:off println if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || proc.isInstanceOf[AddResourceProcessor]) { val driver = new SparkSQLDriver @@ -336,6 +337,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } ret = proc.run(cmd_1).getResponseCode } + // scalastyle:on println } ret } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index bbc39b892b79e..4684d48aff889 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ @@ -65,12 +66,12 @@ private[hive] class HiveQLDialect extends ParserDialect { * * @since 1.0.0 */ -class HiveContext(sc: SparkContext) extends SQLContext(sc) { +class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { self => import HiveContext._ - println("create HiveContext") + logDebug("create HiveContext") /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2de7a99c122fd..7fc517b646b20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException import 
org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -73,7 +74,7 @@ private[hive] case class CreateTableAsSelect( } /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] object HiveQl { +private[hive] object HiveQl extends Logging { protected val nativeCommands = Seq( "TOK_ALTERDATABASE_OWNER", "TOK_ALTERDATABASE_PROPERTIES", @@ -186,7 +187,7 @@ private[hive] object HiveQl { .map(ast => Option(ast).map(_.transform(rule)).orNull)) } catch { case e: Exception => - println(dumpTree(n)) + logError(dumpTree(n).toString) throw e } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index cbd2bf6b5eede..9d83ca6c113dc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -360,7 +360,9 @@ private[hive] class ClientWrapper( case _ => if (state.out != null) { + // scalastyle:off println state.out.println(tokens(0) + " " + cmd_1) + // scalastyle:on println } Seq(proc.run(cmd_1).getResponseCode.toString) } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala index 0e428ba1d7456..2590040f2ec1c 100644 --- a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.hive.HiveContext */ object Main { def main(args: Array[String]) { + // scalastyle:off println println("Running regression test for SPARK-8489.") val sc = new SparkContext("local", "testing") val hc = new HiveContext(sc) @@ -38,6 +39,7 @@ object Main { val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) df.collect() println("Regression test for SPARK-8489 success!") + // scalastyle:on println sc.stop() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index e9bb32667936c..983c013bcf86a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends SparkFunSuite { +class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { test("struct field should accept underscore in sub-column name") { val metastr = "struct" @@ -41,7 +41,7 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite { test("duplicated metastore relations") { import TestHive.implicits._ val df = TestHive.sql("SELECT * FROM src") - println(df.queryExecution) + logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index a38ed23b5cf9a..917900e5f46dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -90,8 +90,10 @@ class HiveSparkSubmitSuite "SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome ).run(ProcessLogger( + // scalastyle:off println (line: String) => { println(s"out> $line") }, (line: String) => { println(s"err> $line") } + // scalastyle:on println )) try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index aa5dbe2db6903..508695919e9a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -86,8 +86,6 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { val message = intercept[QueryExecutionException] { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") }.getMessage - - println("message!!!!" + message) } test("Double create does not fail when allowExisting = true") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index cc294bc3e8bc3..d910af22c3dd1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InvalidInputException +import org.apache.spark.Logging import org.apache.spark.sql._ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive @@ -40,7 +41,8 @@ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. 
*/ -class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll + with Logging { override val sqlContext = TestHive var jsonFilePath: String = _ @@ -415,7 +417,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA |) """.stripMargin) - sql("DROP TABLE jsonTable").collect().foreach(println) + sql("DROP TABLE jsonTable").collect().foreach(i => logInfo(i.toString)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index eaaa88e17002b..1bde5922b5278 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -315,7 +315,6 @@ class PairUDF extends GenericUDF { ) override def evaluate(args: Array[DeferredObject]): AnyRef = { - println("Type = %s".format(args(0).getClass.getName)) Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 192aa6a139bcb..1da0b0a54df07 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -720,12 +720,14 @@ abstract class DStream[T: ClassTag] ( def foreachFunc: (RDD[T], Time) => Unit = { (rdd: RDD[T], time: Time) => { val firstNum = rdd.take(num + 1) + // scalastyle:off println println("-------------------------------------------") println("Time: " + time) println("-------------------------------------------") firstNum.take(num).foreach(println) if (firstNum.length > num) println("...") println() + // scalastyle:on println } } new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index ca2f319f174a2..6addb96752038 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -35,7 +35,9 @@ private[streaming] object RawTextSender extends Logging { def main(args: Array[String]) { if (args.length != 4) { + // scalastyle:off println System.err.println("Usage: RawTextSender ") + // scalastyle:on println System.exit(1) } // Parse the arguments using a pattern match diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index c8eef833eb431..dd32ad5ad811d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -106,7 +106,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: } private[streaming] -object RecurringTimer { +object RecurringTimer extends Logging { def main(args: Array[String]) { var lastRecurTime = 0L @@ -114,7 +114,7 @@ object RecurringTimer { def onRecur(time: Long) { val currentTime = System.currentTimeMillis() - println("" + currentTime + ": " + (currentTime - lastRecurTime)) + logInfo("" + currentTime + ": " + 
(currentTime - lastRecurTime)) lastRecurTime = currentTime } val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index e0f14fd954280..6e9d4431090a2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -43,6 +43,7 @@ object MasterFailureTest extends Logging { @volatile var setupCalled = false def main(args: Array[String]) { + // scalastyle:off println if (args.size < 2) { println( "Usage: MasterFailureTest <# batches> " + @@ -60,6 +61,7 @@ object MasterFailureTest extends Logging { testUpdateStateByKey(directory, numBatches, batchDuration) println("\n\nSUCCESS\n\n") + // scalastyle:on println } def testMap(directory: String, numBatches: Int, batchDuration: Duration) { @@ -291,10 +293,12 @@ object MasterFailureTest extends Logging { } // Log the output + // scalastyle:off println println("Expected output, size = " + expectedOutput.size) println(expectedOutput.mkString("[", ",", "]")) println("Output, size = " + output.size) println(output.mkString("[", ",", "]")) + // scalastyle:on println // Match the output with the expected output output.foreach(o => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index 7865b06c2e3c2..a2dbae149f311 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -76,7 +76,6 @@ class JobGeneratorSuite extends TestSuiteBase { if (time.milliseconds == longBatchTime) { while (waitLatch.getCount() > 0) { waitLatch.await() - println("Await over") } } }) diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 595ded6ae67fa..9483d2b692ab5 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -92,7 +92,9 @@ object GenerateMIMAIgnore { ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } catch { + // scalastyle:off println case _: Throwable => println("Error instrumenting class:" + className) + // scalastyle:on println } } (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) @@ -108,7 +110,9 @@ object GenerateMIMAIgnore { .filter(_.contains("$$")).map(classSymbol.fullName + "." 
+ _) } catch { case t: Throwable => + // scalastyle:off println println("[WARN] Unable to detect inner functions for class:" + classSymbol.fullName) + // scalastyle:on println Seq.empty[String] } } @@ -128,12 +132,14 @@ object GenerateMIMAIgnore { getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-class-excludes") .writeAll(previousContents + privateClasses.mkString("\n")) + // scalastyle:off println println("Created : .generated-mima-class-excludes in current directory.") val previousMembersContents = Try(File(".generated-mima-member-excludes").lines) .getOrElse(Iterator.empty).mkString("\n") File(".generated-mima-member-excludes").writeAll(previousMembersContents + privateMembers.mkString("\n")) println("Created : .generated-mima-member-excludes in current directory.") + // scalastyle:on println } @@ -174,7 +180,9 @@ object GenerateMIMAIgnore { try { classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) } catch { + // scalastyle:off println case _: Throwable => println("Unable to load:" + entry) + // scalastyle:on println } } classes diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 583823c90c5c6..856ea177a9a10 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -323,11 +323,14 @@ object JavaAPICompletenessChecker { val missingMethods = javaEquivalents -- javaMethods for (method <- missingMethods) { + // scalastyle:off println println(method) + // scalastyle:on println } } def main(args: Array[String]) { + // scalastyle:off println println("Missing RDD methods") printMissingMethods(classOf[RDD[_]], classOf[JavaRDD[_]]) println() @@ -359,5 +362,6 @@ object JavaAPICompletenessChecker { println("Missing PairDStream methods") printMissingMethods(classOf[PairDStreamFunctions[_, _]], classOf[JavaPairDStream[_, _]]) println() + // scalastyle:on println } } diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index baa97616eaff3..0dc2861253f17 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -85,7 +85,9 @@ object StoragePerfTester { latch.countDown() } catch { case e: Exception => + // scalastyle:off println println("Exception in child thread: " + e + " " + e.getMessage) + // scalastyle:on println System.exit(1) } } @@ -97,9 +99,11 @@ object StoragePerfTester { val bytesPerSecond = totalBytes.get() / time val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong + // scalastyle:off println System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits)) System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile))) System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong))) + // scalastyle:on println executor.shutdown() sc.stop() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 68e9f6b4db7f4..37f793763367e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ 
b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -85,7 +85,9 @@ class ApplicationMasterArguments(val args: Array[String]) { } if (primaryPyFile != null && primaryRFile != null) { + // scalastyle:off println System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + // scalastyle:on println System.exit(-1) } @@ -93,6 +95,7 @@ class ApplicationMasterArguments(val args: Array[String]) { } def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + // scalastyle:off println if (unknownParam != null) { System.err.println("Unknown/unsupported param " + unknownParam) } @@ -111,6 +114,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) """.stripMargin) + // scalastyle:on println System.exit(exitCode) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4d52ae774ea00..f0af6f875f523 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -938,7 +938,7 @@ private[spark] class Client( object Client extends Logging { def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { - println("WARNING: This client is deprecated and will be removed in a " + + logWarning("WARNING: This client is deprecated and will be removed in a " + "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 19d1bbff9993f..20d63d40cf605 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -123,6 +123,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new SparkException("Executor cores must not be less than " + "spark.task.cpus.") } + // scalastyle:off println if (isClusterMode) { for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) { if (sparkConf.contains(key)) { @@ -144,11 +145,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .map(_.toInt) .foreach { cores => amCores = cores } } + // scalastyle:on println } private def parseArgs(inputArgs: List[String]): Unit = { var args = inputArgs + // scalastyle:off println while (!args.isEmpty) { args match { case ("--jar") :: value :: tail => @@ -253,6 +256,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException(getUsageMessage(args)) } } + // scalastyle:on println if (primaryPyFile != null && primaryRFile != null) { throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 335e966519c7c..547863d9a0739 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -328,12 +328,14 @@ private object YarnClusterDriver extends Logging with Matchers { def main(args: Array[String]): Unit = { if (args.length != 1) { + // scalastyle:off println 
System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClusterDriver [result file] """.stripMargin) + // scalastyle:on println System.exit(1) } @@ -386,12 +388,14 @@ private object YarnClasspathTest { def main(args: Array[String]): Unit = { if (args.length != 2) { + // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClasspathTest [driver result file] [executor result file] """.stripMargin) + // scalastyle:on println System.exit(1) } From 11e22b74a080ea58fb9410b5cc6fa4c03f9198f2 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 10 Jul 2015 16:22:49 +0100 Subject: [PATCH 103/149] [SPARK-7944] [SPARK-8013] Remove most of the Spark REPL fork for Scala 2.11 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR removes most of the code in the Spark REPL for Scala 2.11 and leaves just a couple of overridden methods in `SparkILoop` in order to: - change welcome message - restrict available commands (like `:power`) - initialize Spark context The two codebases have diverged and it's extremely hard to backport fixes from the upstream REPL. This somewhat radical step is absolutely necessary in order to fix other REPL tickets (like SPARK-8013 - Hive Thrift server for 2.11). BTW, the Scala REPL has fixed the serialization-unfriendly wrappers thanks to ScrapCodes's work in [#4522](https://github.com/scala/scala/pull/4522) All tests pass and I tried the `spark-shell` on our Mesos cluster with some simple jobs (including with additional jars), everything looked good. As soon as Scala 2.11.7 is out we need to upgrade and get a shaded `jline` dependency, clearing the way for SPARK-8013. /cc pwendell Author: Iulian Dragos Closes #6903 from dragos/issue/no-spark-repl-fork and squashes the following commits: c596c6f [Iulian Dragos] Merge branch 'master' into issue/no-spark-repl-fork 2b1a305 [Iulian Dragos] Removed spaces around multiple imports. 0ce67a6 [Iulian Dragos] Remove -verbose flag for java compiler (added by mistake in an earlier commit). 10edaf9 [Iulian Dragos] Keep the jline dependency only in the 2.10 build. 529293b [Iulian Dragos] Add back Spark REPL files to rat-excludes, since they are part of the 2.10 real. d85370d [Iulian Dragos] Remove jline dependency from the Spark REPL. b541930 [Iulian Dragos] Merge branch 'master' into issue/no-spark-repl-fork 2b15962 [Iulian Dragos] Change jline dependency and bump Scala version. b300183 [Iulian Dragos] Rename package and add license on top of the file, remove files from rat-excludes and removed `-Yrepl-sync` per reviewer’s request. 9d46d85 [Iulian Dragos] Fix SPARK-7944. abcc7cb [Iulian Dragos] Remove the REPL forked code. 
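As a sketch of how the surviving entry point is meant to be used: the `SparkILoop.run(code, settings)` helper kept by this patch creates a loop with default settings, feeds it the given code, and returns the captured transcript. The object name below, the statements fed to the loop, and the assumption that a local Spark context can be created on startup (the loop still calls `Main.createSparkContext()` when it initializes) are illustrative only, not part of the patch:

```scala
import scala.tools.nsc.Settings

import org.apache.spark.repl.SparkILoop

// Illustrative smoke test (hypothetical, not part of this patch): drive the
// slimmed-down REPL programmatically and inspect the transcript it produces.
object ReplSmokeTest {
  def main(args: Array[String]): Unit = {
    val code =
      """
        |val xs = sc.parallelize(1 to 10)
        |xs.count()
      """.stripMargin
    // SparkILoop.run starts a fresh loop, evaluates `code`, and returns the
    // console output. It assumes a local Spark context can be created.
    val transcript = SparkILoop.run(code, new Settings)
    // Expect "Spark context available as sc." followed by the evaluated statements.
    println(transcript)
  }
}
```

Because the patch routes Spark initialization through the overridden `loadFiles`, the `sc` and `sqlContext` bindings are in place before the REPL interprets any user input or files.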
--- pom.xml | 18 +- repl/pom.xml | 19 +- .../scala/org/apache/spark/repl/Main.scala | 16 +- .../apache/spark/repl/SparkExprTyper.scala | 86 -- .../org/apache/spark/repl/SparkILoop.scala | 971 +----------- .../org/apache/spark/repl/SparkIMain.scala | 1323 ----------------- .../org/apache/spark/repl/SparkImports.scala | 201 --- .../spark/repl/SparkJLineCompletion.scala | 350 ----- .../spark/repl/SparkMemberHandlers.scala | 221 --- .../apache/spark/repl/SparkReplReporter.scala | 55 - .../org/apache/spark/repl/ReplSuite.scala | 11 +- 11 files changed, 90 insertions(+), 3181 deletions(-) delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkImports.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala delete mode 100644 repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala diff --git a/pom.xml b/pom.xml index 172fdef4c73da..c2ebc1a11e770 100644 --- a/pom.xml +++ b/pom.xml @@ -341,11 +341,6 @@ - - ${jline.groupid} - jline - ${jline.version} - com.twitter chill_${scala.binary.version} @@ -1826,6 +1821,15 @@ ${scala.version} org.scala-lang + + + + ${jline.groupid} + jline + ${jline.version} + + + @@ -1844,10 +1848,8 @@ scala-2.11 - 2.11.6 + 2.11.7 2.11 - 2.12.1 - jline diff --git a/repl/pom.xml b/repl/pom.xml index 370b2bc2fa8ed..70c9bd7c01296 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -38,11 +38,6 @@ - - ${jline.groupid} - jline - ${jline.version} - org.apache.spark spark-core_${scala.binary.version} @@ -161,6 +156,20 @@ + + scala-2.10 + + !scala-2.11 + + + + ${jline.groupid} + jline + ${jline.version} + + + + scala-2.11 diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index f4f4b626988e9..eed4a379afa60 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -17,13 +17,14 @@ package org.apache.spark.repl +import java.io.File + +import scala.tools.nsc.Settings + import org.apache.spark.util.Utils import org.apache.spark._ import org.apache.spark.sql.SQLContext -import scala.tools.nsc.Settings -import scala.tools.nsc.interpreter.SparkILoop - object Main extends Logging { val conf = new SparkConf() @@ -32,7 +33,8 @@ object Main extends Logging { val outputDir = Utils.createTempDir(rootDir) val s = new Settings() s.processArguments(List("-Yrepl-class-based", - "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", + "-classpath", getAddedJars.mkString(File.pathSeparator)), true) val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ @@ -48,7 +50,6 @@ object Main extends Logging { Option(sparkContext).map(_.stop) } - def getAddedJars: Array[String] = { val envJars = sys.env.get("ADD_JARS") if (envJars.isDefined) { @@ -84,10 +85,9 @@ object Main extends Logging { val loader = Utils.getContextOrSparkClassLoader try { sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) - .newInstance(sparkContext).asInstanceOf[SQLContext] + 
.newInstance(sparkContext).asInstanceOf[SQLContext] logInfo("Created sql context (with Hive support)..") - } - catch { + } catch { case _: java.lang.ClassNotFoundException | _: java.lang.NoClassDefFoundError => sqlContext = new SQLContext(sparkContext) logInfo("Created sql context..") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala deleted file mode 100644 index 8e519fa67f649..0000000000000 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Paul Phillips - */ - -package scala.tools.nsc -package interpreter - -import scala.tools.nsc.ast.parser.Tokens.EOF - -trait SparkExprTyper { - val repl: SparkIMain - - import repl._ - import global.{ reporter => _, Import => _, _ } - import naming.freshInternalVarName - - def symbolOfLine(code: String): Symbol = { - def asExpr(): Symbol = { - val name = freshInternalVarName() - // Typing it with a lazy val would give us the right type, but runs - // into compiler bugs with things like existentials, so we compile it - // behind a def and strip the NullaryMethodType which wraps the expr. - val line = "def " + name + " = " + code - - interpretSynthetic(line) match { - case IR.Success => - val sym0 = symbolOfTerm(name) - // drop NullaryMethodType - sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) - case _ => NoSymbol - } - } - def asDefn(): Symbol = { - val old = repl.definedSymbolList.toSet - - interpretSynthetic(code) match { - case IR.Success => - repl.definedSymbolList filterNot old match { - case Nil => NoSymbol - case sym :: Nil => sym - case syms => NoSymbol.newOverloaded(NoPrefix, syms) - } - case _ => NoSymbol - } - } - def asError(): Symbol = { - interpretSynthetic(code) - NoSymbol - } - beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() - } - - private var typeOfExpressionDepth = 0 - def typeOfExpression(expr: String, silent: Boolean = true): Type = { - if (typeOfExpressionDepth > 2) { - repldbg("Terminating typeOfExpression recursion for expression: " + expr) - return NoType - } - typeOfExpressionDepth += 1 - // Don't presently have a good way to suppress undesirable success output - // while letting errors through, so it is first trying it silently: if there - // is an error, and errors are desired, then it re-evaluates non-silently - // to induce the error message. - try beSilentDuring(symbolOfLine(expr).tpe) match { - case NoType if !silent => symbolOfLine(expr).tpe // generate error - case tpe => tpe - } - finally typeOfExpressionDepth -= 1 - } - - // This only works for proper types. - def typeOfTypeString(typeString: String): Type = { - def asProperType(): Option[Type] = { - val name = freshInternalVarName() - val line = "def %s: %s = ???" 
format (name, typeString) - interpretSynthetic(line) match { - case IR.Success => - val sym0 = symbolOfTerm(name) - Some(sym0.asMethod.returnType) - case _ => None - } - } - beSilentDuring(asProperType()) getOrElse NoType - } -} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 3c90287249497..bf609ff0f65fc 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -1,88 +1,64 @@ -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Alexander Spoon +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -package scala -package tools.nsc -package interpreter +package org.apache.spark.repl -import scala.language.{ implicitConversions, existentials } -import scala.annotation.tailrec -import Predef.{ println => _, _ } -import interpreter.session._ -import StdReplTags._ -import scala.reflect.api.{Mirror, Universe, TypeCreator} -import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName } -import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } -import scala.reflect.{ClassTag, classTag} -import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader } -import ScalaClassLoader._ -import scala.reflect.io.{ File, Directory } -import scala.tools.util._ -import scala.collection.generic.Clearable -import scala.concurrent.{ ExecutionContext, Await, Future, future } -import ExecutionContext.Implicits._ -import java.io.{ BufferedReader, FileReader } +import java.io.{BufferedReader, FileReader} -/** The Scala interactive shell. It provides a read-eval-print loop - * around the Interpreter class. - * After instantiation, clients should call the main() method. - * - * If no in0 is specified, then input will come from the console, and - * the class will attempt to provide input editing feature such as - * input history. - * - * @author Moez A. 
Abdel-Gawad - * @author Lex Spoon - * @version 1.2 - */ -class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) - extends AnyRef - with LoopCommands -{ - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) - def this() = this(None, new JPrintWriter(Console.out, true)) -// -// @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp -// @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i - - var in: InteractiveReader = _ // the input stream from which commands come - var settings: Settings = _ - var intp: SparkIMain = _ +import Predef.{println => _, _} +import scala.util.Properties.{jdkHome, javaVersion, versionString, javaVmName} - var globalFuture: Future[Boolean] = _ +import scala.tools.nsc.interpreter.{JPrintWriter, ILoop} +import scala.tools.nsc.Settings +import scala.tools.nsc.util.stringFromStream - protected def asyncMessage(msg: String) { - if (isReplInfo || isReplPower) - echoAndRefresh(msg) - } +/** + * A Spark-specific interactive shell. + */ +class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) + extends ILoop(in0, out) { + def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) + def this() = this(None, new JPrintWriter(Console.out, true)) def initializeSpark() { intp.beQuietDuring { - command( """ + processLine(""" @transient val sc = { val _sc = org.apache.spark.repl.Main.createSparkContext() println("Spark context available as sc.") _sc } """) - command( """ + processLine(""" @transient val sqlContext = { val _sqlContext = org.apache.spark.repl.Main.createSQLContext() println("SQL context available as sqlContext.") _sqlContext } """) - command("import org.apache.spark.SparkContext._") - command("import sqlContext.implicits._") - command("import sqlContext.sql") - command("import org.apache.spark.sql.functions._") + processLine("import org.apache.spark.SparkContext._") + processLine("import sqlContext.implicits._") + processLine("import sqlContext.sql") + processLine("import org.apache.spark.sql.functions._") } } /** Print a welcome message */ - def printWelcome() { + override def printWelcome() { import org.apache.spark.SPARK_VERSION echo("""Welcome to ____ __ @@ -98,877 +74,42 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) echo("Type :help for more information.") } - override def echoCommandMessage(msg: String) { - intp.reporter printUntruncatedMessage msg - } - - // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) - def history = in.history - - // classpath entries added via :cp - var addedClasspath: String = "" - - /** A reverse list of commands to replay if the user requests a :replay */ - var replayCommandStack: List[String] = Nil - - /** A list of commands to replay if the user requests a :replay */ - def replayCommands = replayCommandStack.reverse - - /** Record a command for replay should the user request a :replay */ - def addReplay(cmd: String) = replayCommandStack ::= cmd - - def savingReplayStack[T](body: => T): T = { - val saved = replayCommandStack - try body - finally replayCommandStack = saved - } - def savingReader[T](body: => T): T = { - val saved = in - try body - finally in = saved - } - - /** Close the interpreter and set the var to null. 
*/ - def closeInterpreter() { - if (intp ne null) { - intp.close() - intp = null - } - } - - class SparkILoopInterpreter extends SparkIMain(settings, out) { - outer => - - override lazy val formatting = new Formatting { - def prompt = SparkILoop.this.prompt - } - override protected def parentClassLoader = - settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader ) - } - - /** Create a new interpreter. */ - def createInterpreter() { - if (addedClasspath != "") - settings.classpath append addedClasspath - - intp = new SparkILoopInterpreter - } - - /** print a friendly help message */ - def helpCommand(line: String): Result = { - if (line == "") helpSummary() - else uniqueCommand(line) match { - case Some(lc) => echo("\n" + lc.help) - case _ => ambiguousError(line) - } - } - private def helpSummary() = { - val usageWidth = commands map (_.usageMsg.length) max - val formatStr = "%-" + usageWidth + "s %s" - - echo("All commands can be abbreviated, e.g. :he instead of :help.") - - commands foreach { cmd => - echo(formatStr.format(cmd.usageMsg, cmd.help)) - } - } - private def ambiguousError(cmd: String): Result = { - matchingCommands(cmd) match { - case Nil => echo(cmd + ": no such command. Type :help for help.") - case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") - } - Result(keepRunning = true, None) - } - private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) - private def uniqueCommand(cmd: String): Option[LoopCommand] = { - // this lets us add commands willy-nilly and only requires enough command to disambiguate - matchingCommands(cmd) match { - case List(x) => Some(x) - // exact match OK even if otherwise appears ambiguous - case xs => xs find (_.name == cmd) - } - } - - /** Show the history */ - lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { - override def usage = "[num]" - def defaultLines = 20 - - def apply(line: String): Result = { - if (history eq NoHistory) - return "No history available." - - val xs = words(line) - val current = history.index - val count = try xs.head.toInt catch { case _: Exception => defaultLines } - val lines = history.asStrings takeRight count - val offset = current - lines.size + 1 - - for ((line, index) <- lines.zipWithIndex) - echo("%3d %s".format(index + offset, line)) - } - } - - // When you know you are most likely breaking into the middle - // of a line being typed. This softens the blow. 
- protected def echoAndRefresh(msg: String) = { - echo("\n" + msg) - in.redrawLine() - } - protected def echo(msg: String) = { - out println msg - out.flush() - } - - /** Search the history */ - def searchHistory(_cmdline: String) { - val cmdline = _cmdline.toLowerCase - val offset = history.index - history.size + 1 - - for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) - echo("%d %s".format(index + offset, line)) - } - - private val currentPrompt = Properties.shellPromptString - - /** Prompt to print when awaiting input */ - def prompt = currentPrompt - import LoopCommand.{ cmd, nullary } - /** Standard commands **/ - lazy val standardCommands = List( - cmd("cp", "", "add a jar or directory to the classpath", addClasspath), - cmd("edit", "|", "edit history", editCommand), - cmd("help", "[command]", "print this summary or command-specific help", helpCommand), - historyCommand, - cmd("h?", "", "search the history", searchHistory), - cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), - //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand), - cmd("javap", "", "disassemble a file or class name", javapCommand), - cmd("line", "|", "place line(s) at the end of history", lineCommand), - cmd("load", "", "interpret lines in a file", loadCommand), - cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand), - // nullary("power", "enable power user mode", powerCmd), - nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)), - nullary("replay", "reset execution and replay all previous commands", replay), - nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), - cmd("save", "", "save replayable session to a file", saveCommand), - shCommand, - cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings), - nullary("silent", "disable/enable automatic printing of results", verbosity), -// cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), -// cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand), - nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) - ) - - /** Power user commands */ -// lazy val powerCommands: List[LoopCommand] = List( -// cmd("phase", "", "set the implicit phase for power commands", phaseCommand) -// ) - - private def importsCommand(line: String): Result = { - val tokens = words(line) - val handlers = intp.languageWildcardHandlers ++ intp.importHandlers - - handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { - case (handler, idx) => - val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) - val imps = handler.implicitSymbols - val found = tokens filter (handler importsSymbolNamed _) - val typeMsg = if (types.isEmpty) "" else types.size + " types" - val termMsg = if (terms.isEmpty) "" else terms.size + " terms" - val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" - val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") - val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") - - intp.reporter.printMessage("%2d) %-30s %s%s".format( - idx + 1, - handler.importString, - statsMsg, - foundMsg - )) - } - } - - private def findToolsJar() = PathResolver.SupplementalLocations.platformTools + private val 
blockedCommands = Set("implicits", "javap", "power", "type", "kind") - private def addToolsJarToLoader() = { - val cl = findToolsJar() match { - case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) - case _ => intp.classLoader - } - if (Javap.isAvailable(cl)) { - repldbg(":javap available.") - cl - } - else { - repldbg(":javap unavailable: no tools.jar at " + jdkHome) - intp.classLoader - } - } -// -// protected def newJavap() = -// JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp)) -// -// private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) - - // Still todo: modules. -// private def typeCommand(line0: String): Result = { -// line0.trim match { -// case "" => ":type [-v] " -// case s => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") -// } -// } - -// private def kindCommand(expr: String): Result = { -// expr.trim match { -// case "" => ":kind [-v] " -// case s => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") -// } -// } - - private def warningsCommand(): Result = { - if (intp.lastWarnings.isEmpty) - "Can't find any cached warnings." - else - intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } - } - - private def changeSettings(args: String): Result = { - def showSettings() = { - for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString) - } - def updateSettings() = { - // put aside +flag options - val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+")) - val tmps = new Settings - val (ok, leftover) = tmps.processArguments(rest, processAll = true) - if (!ok) echo("Bad settings request.") - else if (leftover.nonEmpty) echo("Unprocessed settings.") - else { - // boolean flags set-by-user on tmp copy should be off, not on - val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting]) - val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg)) - // update non-flags - settings.processArguments(nonbools, processAll = true) - // also snag multi-value options for clearing, e.g. -Ylog: and -language: - for { - s <- settings.userSetSettings - if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting] - if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init)) - } s match { - case c: Clearable => c.clear() - case _ => - } - def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = { - for (b <- bs) - settings.lookupSetting(name(b)) match { - case Some(s) => - if (s.isInstanceOf[Settings#BooleanSetting]) setter(s) - else echo(s"Not a boolean flag: $b") - case _ => - echo(s"Not an option: $b") - } - } - update(minuses, identity, _.tryToSetFromPropertyValue("false")) // turn off - update(pluses, "-" + _.drop(1), _.tryToSet(Nil)) // turn on - } - } - if (args.isEmpty) showSettings() else updateSettings() - } - - private def javapCommand(line: String): Result = { -// if (javap == null) -// ":javap unavailable, no tools.jar at %s. 
Set JDK_HOME.".format(jdkHome) -// else if (line == "") -// ":javap [-lcsvp] [path1 path2 ...]" -// else -// javap(words(line)) foreach { res => -// if (res.isError) return "Failed: " + res.value -// else res.show() -// } - } - - private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent" - - private def phaseCommand(name: String): Result = { -// val phased: Phased = power.phased -// import phased.NoPhaseName -// -// if (name == "clear") { -// phased.set(NoPhaseName) -// intp.clearExecutionWrapper() -// "Cleared active phase." -// } -// else if (name == "") phased.get match { -// case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" -// case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) -// } -// else { -// val what = phased.parse(name) -// if (what.isEmpty || !phased.set(what)) -// "'" + name + "' does not appear to represent a valid phase." -// else { -// intp.setExecutionWrapper(pathToPhaseWrapper) -// val activeMessage = -// if (what.toString.length == name.length) "" + what -// else "%s (%s)".format(what, name) -// -// "Active phase is now: " + activeMessage -// } -// } - } + /** Standard commands **/ + lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] = + standardCommands.filter(cmd => !blockedCommands(cmd.name)) /** Available commands */ - def commands: List[LoopCommand] = standardCommands ++ ( - // if (isReplPower) - // powerCommands - // else - Nil - ) - - val replayQuestionMessage = - """|That entry seems to have slain the compiler. Shall I replay - |your session? I can re-run each line except the last one. - |[y/n] - """.trim.stripMargin - - private val crashRecovery: PartialFunction[Throwable, Boolean] = { - case ex: Throwable => - val (err, explain) = ( - if (intp.isInitializeComplete) - (intp.global.throwableAsString(ex), "") - else - (ex.getMessage, "The compiler did not initialize.\n") - ) - echo(err) - - ex match { - case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("\nUnrecoverable error.") - throw ex - case _ => - def fn(): Boolean = - try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) - catch { case _: RuntimeException => false } - - if (fn()) replay() - else echo("\nAbandoning crashed session.") - } - true - } - - // return false if repl should exit - def processLine(line: String): Boolean = { - import scala.concurrent.duration._ - Await.ready(globalFuture, 60.seconds) - - (line ne null) && (command(line) match { - case Result(false, _) => false - case Result(_, Some(line)) => addReplay(line) ; true - case _ => true - }) - } - - private def readOneLine() = { - out.flush() - in readLine prompt - } - - /** The main read-eval-print loop for the repl. It calls - * command() for each line of input, and stops when - * command() returns false. 
- */ - @tailrec final def loop() { - if ( try processLine(readOneLine()) catch crashRecovery ) - loop() - } - - /** interpret all lines from a specified file */ - def interpretAllFrom(file: File) { - savingReader { - savingReplayStack { - file applyReader { reader => - in = SimpleReader(reader, out, interactive = false) - echo("Loading " + file + "...") - loop() - } - } - } - } - - /** create a new interpreter and replay the given commands */ - def replay() { - reset() - if (replayCommandStack.isEmpty) - echo("Nothing to replay.") - else for (cmd <- replayCommands) { - echo("Replaying: " + cmd) // flush because maybe cmd will have its own output - command(cmd) - echo("") - } - } - def resetCommand() { - echo("Resetting interpreter state.") - if (replayCommandStack.nonEmpty) { - echo("Forgetting this session history:\n") - replayCommands foreach echo - echo("") - replayCommandStack = Nil - } - if (intp.namedDefinedTerms.nonEmpty) - echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) - if (intp.definedTypes.nonEmpty) - echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) - - reset() - } - def reset() { - intp.reset() - unleashAndSetPhase() - } - - def lineCommand(what: String): Result = editCommand(what, None) - - // :edit id or :edit line - def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR")) - - def editCommand(what: String, editor: Option[String]): Result = { - def diagnose(code: String) = { - echo("The edited code is incomplete!\n") - val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") - if (errless) echo("The compiler reports no errors.") - } - def historicize(text: String) = history match { - case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true - case _ => false - } - def edit(text: String): Result = editor match { - case Some(ed) => - val tmp = File.makeTemp() - tmp.writeAll(text) - try { - val pr = new ProcessResult(s"$ed ${tmp.path}") - pr.exitCode match { - case 0 => - tmp.safeSlurp() match { - case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.") - case Some(edited) => - echo(edited.lines map ("+" + _) mkString "\n") - val res = intp interpret edited - if (res == IR.Incomplete) diagnose(edited) - else { - historicize(edited) - Result(lineToRecord = Some(edited), keepRunning = true) - } - case None => echo("Can't read edited text. 
Did you delete it?") - } - case x => echo(s"Error exit from $ed ($x), ignoring") - } - } finally { - tmp.delete() - } - case None => - if (historicize(text)) echo("Placing text in recent history.") - else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text") - } - - // if what is a number, use it as a line number or range in history - def isNum = what forall (c => c.isDigit || c == '-' || c == '+') - // except that "-" means last value - def isLast = (what == "-") - if (isLast || !isNum) { - val name = if (isLast) intp.mostRecentVar else what - val sym = intp.symbolOfIdent(name) - intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match { - case Some(req) => edit(req.line) - case None => echo(s"No symbol in scope: $what") - } - } else try { - val s = what - // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur) - val (start, len) = - if ((s indexOf '+') > 0) { - val (a,b) = s splitAt (s indexOf '+') - (a.toInt, b.drop(1).toInt) - } else { - (s indexOf '-') match { - case -1 => (s.toInt, 1) - case 0 => val n = s.drop(1).toInt ; (history.index - n, n) - case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n) - case i => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n) - } - } - import scala.collection.JavaConverters._ - val index = (start - 1) max 0 - val text = history match { - case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n" - case _ => history.asStrings.slice(index, index + len) mkString "\n" - } - edit(text) - } catch { - case _: NumberFormatException => echo(s"Bad range '$what'") - echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)") - } - } - - /** fork a shell and run a command */ - lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { - override def usage = "" - def apply(line: String): Result = line match { - case "" => showUsage() - case _ => - val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})" - intp interpret toRun - () - } - } - - def withFile[A](filename: String)(action: File => A): Option[A] = { - val res = Some(File(filename)) filter (_.exists) map action - if (res.isEmpty) echo("That file does not exist") // courtesy side-effect - res - } - - def loadCommand(arg: String) = { - var shouldReplay: Option[String] = None - withFile(arg)(f => { - interpretAllFrom(f) - shouldReplay = Some(":load " + arg) - }) - Result(keepRunning = true, shouldReplay) - } - - def saveCommand(filename: String): Result = ( - if (filename.isEmpty) echo("File name is required.") - else if (replayCommandStack.isEmpty) echo("No replay commands in session") - else File(filename).printlnAll(replayCommands: _*) - ) - - def addClasspath(arg: String): Unit = { - val f = File(arg).normalize - if (f.exists) { - addedClasspath = ClassPath.join(addedClasspath, f.path) - val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) - echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) - replay() - } - else echo("The path '" + f + "' doesn't seem to exist.") - } - - def powerCmd(): Result = { - if (isReplPower) "Already in power mode." 
- else enablePowerMode(isDuringInit = false) - } - def enablePowerMode(isDuringInit: Boolean) = { - replProps.power setValue true - unleashAndSetPhase() - // asyncEcho(isDuringInit, power.banner) - } - private def unleashAndSetPhase() { - if (isReplPower) { - // power.unleash() - // Set the phase to "typer" - // intp beSilentDuring phaseCommand("typer") - } - } - - def asyncEcho(async: Boolean, msg: => String) { - if (async) asyncMessage(msg) - else echo(msg) - } - - def verbosity() = { - val old = intp.printResults - intp.printResults = !old - echo("Switched " + (if (old) "off" else "on") + " result printing.") - } - - /** Run one command submitted by the user. Two values are returned: - * (1) whether to keep running, (2) the line to record for replay, - * if any. */ - def command(line: String): Result = { - if (line startsWith ":") { - val cmd = line.tail takeWhile (x => !x.isWhitespace) - uniqueCommand(cmd) match { - case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) - case _ => ambiguousError(cmd) - } - } - else if (intp.global == null) Result(keepRunning = false, None) // Notice failure to create compiler - else Result(keepRunning = true, interpretStartingWith(line)) - } - - private def readWhile(cond: String => Boolean) = { - Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) - } - - def pasteCommand(arg: String): Result = { - var shouldReplay: Option[String] = None - def result = Result(keepRunning = true, shouldReplay) - val (raw, file) = - if (arg.isEmpty) (false, None) - else { - val r = """(-raw)?(\s+)?([^\-]\S*)?""".r - arg match { - case r(flag, sep, name) => - if (flag != null && name != null && sep == null) - echo(s"""I assume you mean "$flag $name"?""") - (flag != null, Option(name)) - case _ => - echo("usage: :paste -raw file") - return result - } - } - val code = file match { - case Some(name) => - withFile(name)(f => { - shouldReplay = Some(s":paste $arg") - val s = f.slurp.trim - if (s.isEmpty) echo(s"File contains no code: $f") - else echo(s"Pasting file $f...") - s - }) getOrElse "" - case None => - echo("// Entering paste mode (ctrl-D to finish)\n") - val text = (readWhile(_ => true) mkString "\n").trim - if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n") - else echo("\n// Exiting paste mode, now interpreting.\n") - text - } - def interpretCode() = { - val res = intp interpret code - // if input is incomplete, let the compiler try to say why - if (res == IR.Incomplete) { - echo("The pasted code is incomplete!\n") - // Remembrance of Things Pasted in an object - val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") - if (errless) echo("...but compilation found no error? Good luck with that.") - } - } - def compileCode() = { - val errless = intp compileSources new BatchSourceFile("", code) - if (!errless) echo("There were compilation errors!") - } - if (code.nonEmpty) { - if (raw) compileCode() else interpretCode() - } - result - } - - private object paste extends Pasted { - val ContinueString = " | " - val PromptString = "scala> " - - def interpret(line: String): Unit = { - echo(line.trim) - intp interpret line - echo("") - } - - def transcript(start: String) = { - echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") - apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) - } - } - import paste.{ ContinueString, PromptString } - - /** Interpret expressions starting with the first line. 
- * Read lines until a complete compilation unit is available - * or until a syntax error has been seen. If a full unit is - * read, go ahead and interpret it. Return the full string - * to be recorded for replay, if any. - */ - def interpretStartingWith(code: String): Option[String] = { - // signal completion non-completion input has been received - in.completion.resetVerbosity() - - def reallyInterpret = { - val reallyResult = intp.interpret(code) - (reallyResult, reallyResult match { - case IR.Error => None - case IR.Success => Some(code) - case IR.Incomplete => - if (in.interactive && code.endsWith("\n\n")) { - echo("You typed two blank lines. Starting a new command.") - None - } - else in.readLine(ContinueString) match { - case null => - // we know compilation is going to fail since we're at EOF and the - // parser thinks the input is still incomplete, but since this is - // a file being read non-interactively we want to fail. So we send - // it straight to the compiler for the nice error message. - intp.compileString(code) - None - - case line => interpretStartingWith(code + "\n" + line) - } - }) - } - - /** Here we place ourselves between the user and the interpreter and examine - * the input they are ostensibly submitting. We intervene in several cases: - * - * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. - * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation - * on the previous result. - * 3) If the Completion object's execute returns Some(_), we inject that value - * and avoid the interpreter, as it's likely not valid scala code. - */ - if (code == "") None - else if (!paste.running && code.trim.startsWith(PromptString)) { - paste.transcript(code) - None - } - else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { - interpretStartingWith(intp.mostRecentVar + code) - } - else if (code.trim startsWith "//") { - // line comment, do nothing - None - } - else - reallyInterpret._2 - } - - // runs :load `file` on any files passed via -i - def loadFiles(settings: Settings) = settings match { - case settings: GenericRunnerSettings => - for (filename <- settings.loadfiles.value) { - val cmd = ":load " + filename - command(cmd) - addReplay(cmd) - echo("") - } - case _ => - } - - /** Tries to create a JLineReader, falling back to SimpleReader: - * unless settings or properties are such that it should start - * with SimpleReader. - */ - def chooseReader(settings: Settings): InteractiveReader = { - if (settings.Xnojline || Properties.isEmacsShell) - SimpleReader() - else try new JLineReader( - if (settings.noCompletion) NoCompletion - else new SparkJLineCompletion(intp) - ) - catch { - case ex @ (_: Exception | _: NoClassDefFoundError) => - echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.") - SimpleReader() - } - } - protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = - u.TypeTag[T]( - m, - new TypeCreator { - def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type = - m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] - }) - - private def loopPostInit() { - // Bind intp somewhere out of the regular namespace where - // we can get at it in generated code. - intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain])) - // Auto-run code via some setting. 
- ( replProps.replAutorunCode.option - flatMap (f => io.File(f).safeSlurp()) - foreach (intp quietRun _) - ) - // classloader and power mode setup - intp.setContextClassLoader() - if (isReplPower) { - // replProps.power setValue true - // unleashAndSetPhase() - // asyncMessage(power.banner) - } - // SI-7418 Now, and only now, can we enable TAB completion. - in match { - case x: JLineReader => x.consoleReader.postInit - case _ => - } - } - def process(settings: Settings): Boolean = savingContextLoader { - this.settings = settings - createInterpreter() - - // sets in to some kind of reader depending on environmental cues - in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) - globalFuture = future { - intp.initializeSynchronous() - loopPostInit() - !intp.reporter.hasErrors - } - import scala.concurrent.duration._ - Await.ready(globalFuture, 10 seconds) - printWelcome() + override def commands: List[LoopCommand] = sparkStandardCommands + + /** + * We override `loadFiles` because we need to initialize Spark *before* the REPL + * sees any files, so that the Spark context is visible in those files. This is a bit of a + * hack, but there isn't another hook available to us at this point. + */ + override def loadFiles(settings: Settings): Unit = { initializeSpark() - loadFiles(settings) - - try loop() - catch AbstractOrMissingHandler() - finally closeInterpreter() - - true + super.loadFiles(settings) } - - @deprecated("Use `process` instead", "2.9.0") - def main(settings: Settings): Unit = process(settings) //used by sbt } object SparkILoop { - implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp - // Designed primarily for use by test code: take a String with a - // bunch of code, and prints out a transcript of what it would look - // like if you'd just typed it into the repl. - def runForTranscript(code: String, settings: Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { - override def write(str: String) = { - // completely skip continuation lines - if (str forall (ch => ch.isWhitespace || ch == '|')) () - else super.write(str) - } - } - val input = new BufferedReader(new StringReader(code.trim + "\n")) { - override def readLine(): String = { - val s = super.readLine() - // helping out by printing the line being interpreted. - if (s != null) - // scalastyle:off println - output.println(s) - // scalastyle:on println - s - } - } - val repl = new SparkILoop(input, output) - if (settings.classpath.isDefault) - settings.classpath.value = sys.props("java.class.path") - - repl process settings - } - } - } - - /** Creates an interpreter loop with default settings and feeds - * the given code to it as input. - */ + /** + * Creates an interpreter loop with default settings and feeds + * the given code to it as input. 
+ */ def run(code: String, sets: Settings = new Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } stringFromStream { ostream => Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new SparkILoop(input, output) + val input = new BufferedReader(new StringReader(code)) + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) if (sets.classpath.isDefault) sets.classpath.value = sys.props("java.class.path") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala deleted file mode 100644 index 56c009a4e38e7..0000000000000 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ /dev/null @@ -1,1323 +0,0 @@ -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Martin Odersky - */ - -package scala -package tools.nsc -package interpreter - -import PartialFunction.cond -import scala.language.implicitConversions -import scala.beans.BeanProperty -import scala.collection.mutable -import scala.concurrent.{ Future, ExecutionContext } -import scala.reflect.runtime.{ universe => ru } -import scala.reflect.{ ClassTag, classTag } -import scala.reflect.internal.util.{ BatchSourceFile, SourceFile } -import scala.tools.util.PathResolver -import scala.tools.nsc.io.AbstractFile -import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings } -import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps } -import scala.tools.nsc.util.Exceptional.unwrap -import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable} - -/** An interpreter for Scala code. - * - * The main public entry points are compile(), interpret(), and bind(). - * The compile() method loads a complete Scala file. The interpret() method - * executes one line of Scala code at the request of the user. The bind() - * method binds an object to a variable that can then be used by later - * interpreted code. - * - * The overall approach is based on compiling the requested code and then - * using a Java classloader and Java reflection to run the code - * and access its results. - * - * In more detail, a single compiler instance is used - * to accumulate all successfully compiled or interpreted Scala code. To - * "interpret" a line of code, the compiler generates a fresh object that - * includes the line of code and which has public member(s) to export - * all variables defined by that code. To extract the result of an - * interpreted line to show the user, a second "result object" is created - * which imports the variables exported by the above object and then - * exports members called "$eval" and "$print". To accomodate user expressions - * that read from variables or methods defined in previous statements, "import" - * statements are used. - * - * This interpreter shares the strengths and weaknesses of using the - * full compiler-to-Java. The main strength is that interpreted code - * behaves exactly as does compiled code, including running at full speed. - * The main weakness is that redefining classes and methods is not handled - * properly, because rebinding at the Java level is technically difficult. - * - * @author Moez A. 
Abdel-Gawad - * @author Lex Spoon - */ -class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings, - protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports { - imain => - - setBindings(createBindings, ScriptContext.ENGINE_SCOPE) - object replOutput extends ReplOutput(settings.Yreploutdir) { } - - @deprecated("Use replOutput.dir instead", "2.11.0") - def virtualDirectory = replOutput.dir - // Used in a test case. - def showDirectory() = replOutput.show(out) - - private[nsc] var printResults = true // whether to print result lines - private[nsc] var totalSilence = false // whether to print anything - private var _initializeComplete = false // compiler is initialized - private var _isInitialized: Future[Boolean] = null // set up initialization future - private var bindExceptions = true // whether to bind the lastException variable - private var _executionWrapper = "" // code to be wrapped around all lines - - /** We're going to go to some trouble to initialize the compiler asynchronously. - * It's critical that nothing call into it until it's been initialized or we will - * run into unrecoverable issues, but the perceived repl startup time goes - * through the roof if we wait for it. So we initialize it with a future and - * use a lazy val to ensure that any attempt to use the compiler object waits - * on the future. - */ - private var _classLoader: util.AbstractFileClassLoader = null // active classloader - private val _compiler: ReplGlobal = newCompiler(settings, reporter) // our private compiler - - def compilerClasspath: Seq[java.net.URL] = ( - if (isInitializeComplete) global.classPath.asURLs - else new PathResolver(settings).result.asURLs // the compiler's classpath - ) - def settings = initialSettings - // Run the code body with the given boolean settings flipped to true. - def withoutWarnings[T](body: => T): T = beQuietDuring { - val saved = settings.nowarn.value - if (!saved) - settings.nowarn.value = true - - try body - finally if (!saved) settings.nowarn.value = false - } - - /** construct an interpreter that reports to Console */ - def this(settings: Settings, out: JPrintWriter) = this(null, settings, out) - def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this(factory: ScriptEngineFactory) = this(factory, new Settings()) - def this() = this(new Settings()) - - lazy val formatting: Formatting = new Formatting { - val prompt = Properties.shellPromptString - } - lazy val reporter: SparkReplReporter = new SparkReplReporter(this) - - import formatting._ - import reporter.{ printMessage, printUntruncatedMessage } - - // This exists mostly because using the reporter too early leads to deadlock. 
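// Illustrative sketch, not part of this patch: a simplified version of the
// "initialize expensively in the background, block on first use" pattern the
// deleted SparkIMain describes above (a Future to hide startup latency, a
// lazy val as the barrier). ExpensiveThing stands in for the compiler.
object AsyncInitSketch {
  import scala.concurrent.{Await, Future}
  import scala.concurrent.ExecutionContext.Implicits.global
  import scala.concurrent.duration.Duration

  final class ExpensiveThing { val ready = true }

  @volatile private var started: Future[ExpensiveThing] = null

  // Kick off initialization without blocking the caller (cf. initialize()).
  def initialize(): Unit = synchronized {
    if (started == null) started = Future(new ExpensiveThing)
  }

  // The first real use waits for initialization to finish (cf. lazy val global).
  lazy val thing: ExpensiveThing = {
    initialize()
    Await.result(started, Duration.Inf)
  }
}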
- private def echo(msg: String) { Console println msg } - private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) - private def _initialize() = { - try { - // if this crashes, REPL will hang its head in shame - val run = new _compiler.Run() - assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") - run compileSources _initSources - _initializeComplete = true - true - } - catch AbstractOrMissingHandler() - } - private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" - private val logScope = scala.sys.props contains "scala.repl.scope" - // scalastyle:off println - private def scopelog(msg: String) = if (logScope) Console.err.println(msg) - // scalastyle:on println - - // argument is a thunk to execute after init is done - def initialize(postInitSignal: => Unit) { - synchronized { - if (_isInitialized == null) { - _isInitialized = - Future(try _initialize() finally postInitSignal)(ExecutionContext.global) - } - } - } - def initializeSynchronous(): Unit = { - if (!isInitializeComplete) { - _initialize() - assert(global != null, global) - } - } - def isInitializeComplete = _initializeComplete - - lazy val global: Global = { - if (!isInitializeComplete) _initialize() - _compiler - } - - import global._ - import definitions.{ ObjectClass, termMember, dropNullaryMethod} - - lazy val runtimeMirror = ru.runtimeMirror(classLoader) - - private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol } - - def getClassIfDefined(path: String) = ( - noFatal(runtimeMirror staticClass path) - orElse noFatal(rootMirror staticClass path) - ) - def getModuleIfDefined(path: String) = ( - noFatal(runtimeMirror staticModule path) - orElse noFatal(rootMirror staticModule path) - ) - - implicit class ReplTypeOps(tp: Type) { - def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) - } - - // TODO: If we try to make naming a lazy val, we run into big time - // scalac unhappiness with what look like cycles. It has not been easy to - // reduce, but name resolution clearly takes different paths. - object naming extends { - val global: imain.global.type = imain.global - } with Naming { - // make sure we don't overwrite their unwisely named res3 etc. 
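// Illustrative sketch, not part of this patch: handing out fresh resN-style
// names while skipping any name the user already defined, in the spirit of
// the freshUserTermName logic that follows. `taken` stands in for replScope.
object FreshNameSketch {
  private var counter = -1
  def freshUserVarName(taken: Set[String]): String = {
    counter += 1
    val candidate = s"res$counter"
    if (taken(candidate)) freshUserVarName(taken) else candidate
  }
  // e.g. freshUserVarName(Set("res0", "res1")) == "res2" on a fresh counter
}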
- def freshUserTermName(): TermName = { - val name = newTermName(freshUserVarName()) - if (replScope containsName name) freshUserTermName() - else name - } - def isInternalTermName(name: Name) = isInternalVarName("" + name) - } - import naming._ - - object deconstruct extends { - val global: imain.global.type = imain.global - } with StructuredTypeStrings - - lazy val memberHandlers = new { - val intp: imain.type = imain - } with SparkMemberHandlers - import memberHandlers._ - - /** Temporarily be quiet */ - def beQuietDuring[T](body: => T): T = { - val saved = printResults - printResults = false - try body - finally printResults = saved - } - def beSilentDuring[T](operation: => T): T = { - val saved = totalSilence - totalSilence = true - try operation - finally totalSilence = saved - } - - def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - /** takes AnyRef because it may be binding a Throwable or an Exceptional */ - private def withLastExceptionLock[T](body: => T, alt: => T): T = { - assert(bindExceptions, "withLastExceptionLock called incorrectly.") - bindExceptions = false - - try beQuietDuring(body) - catch logAndDiscard("withLastExceptionLock", alt) - finally bindExceptions = true - } - - def executionWrapper = _executionWrapper - def setExecutionWrapper(code: String) = _executionWrapper = code - def clearExecutionWrapper() = _executionWrapper = "" - - /** interpreter settings */ - lazy val isettings = new SparkISettings(this) - - /** Instantiate a compiler. Overridable. */ - protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = { - settings.outputDirs setSingleOutput replOutput.dir - settings.exposeEmptyPackage.value = true - new Global(settings, reporter) with ReplGlobal { override def toString: String = "" } - } - - /** Parent classloader. Overridable. */ - protected def parentClassLoader: ClassLoader = - settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() ) - - /* A single class loader is used for all commands interpreted by this Interpreter. - It would also be possible to create a new class loader for each command - to interpret. The advantages of the current approach are: - - - Expressions are only evaluated one time. This is especially - significant for I/O, e.g. "val x = Console.readLine" - - The main disadvantage is: - - - Objects, classes, and methods cannot be rebound. Instead, definitions - shadow the old ones, and old code objects refer to the old - definitions. - */ - def resetClassLoader() = { - repldbg("Setting new classloader: was " + _classLoader) - _classLoader = null - ensureClassLoader() - } - final def ensureClassLoader() { - if (_classLoader == null) - _classLoader = makeClassLoader() - } - def classLoader: util.AbstractFileClassLoader = { - ensureClassLoader() - _classLoader - } - - def backticked(s: String): String = ( - (s split '.').toList map { - case "_" => "_" - case s if nme.keywords(newTermName(s)) => s"`$s`" - case s => s - } mkString "." 
- ) - def readRootPath(readPath: String) = getModuleIfDefined(readPath) - - abstract class PhaseDependentOps { - def shift[T](op: => T): T - - def path(name: => Name): String = shift(path(symbolOfName(name))) - def path(sym: Symbol): String = backticked(shift(sym.fullName)) - def sig(sym: Symbol): String = shift(sym.defString) - } - object typerOp extends PhaseDependentOps { - def shift[T](op: => T): T = exitingTyper(op) - } - object flatOp extends PhaseDependentOps { - def shift[T](op: => T): T = exitingFlatten(op) - } - - def originalPath(name: String): String = originalPath(name: TermName) - def originalPath(name: Name): String = typerOp path name - def originalPath(sym: Symbol): String = typerOp path sym - def flatPath(sym: Symbol): String = flatOp shift sym.javaClassName - def translatePath(path: String) = { - val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path) - sym.toOption map flatPath - } - def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath - - private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) { - /** Overridden here to try translating a simple name to the generated - * class name if the original attempt fails. This method is used by - * getResourceAsStream as well as findClass. - */ - override protected def findAbstractFile(name: String): AbstractFile = - super.findAbstractFile(name) match { - case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull - case file => file - } - } - private def makeClassLoader(): util.AbstractFileClassLoader = - new TranslatingClassLoader(parentClassLoader match { - case null => ScalaClassLoader fromURLs compilerClasspath - case p => new ScalaClassLoader.URLClassLoader(compilerClasspath, p) - }) - - // Set the current Java "context" class loader to this interpreter's class loader - def setContextClassLoader() = classLoader.setAsContext() - - def allDefinedNames: List[Name] = exitingTyper(replScope.toList.map(_.name).sorted) - def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted - - /** Most recent tree handled which wasn't wholly synthetic. */ - private def mostRecentlyHandledTree: Option[Tree] = { - prevRequests.reverse foreach { req => - req.handlers.reverse foreach { - case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) - case _ => () - } - } - None - } - - private def updateReplScope(sym: Symbol, isDefined: Boolean) { - def log(what: String) { - val mark = if (sym.isType) "t " else "v " - val name = exitingTyper(sym.nameString) - val info = cleanTypeAfterTyper(sym) - val defn = sym defStringSeenAs info - - scopelog(f"[$mark$what%6s] $name%-25s $defn%s") - } - if (ObjectClass isSubClass sym.owner) return - // unlink previous - replScope lookupAll sym.name foreach { sym => - log("unlink") - replScope unlink sym - } - val what = if (isDefined) "define" else "import" - log(what) - replScope enter sym - } - - def recordRequest(req: Request) { - if (req == null) - return - - prevRequests += req - - // warning about serially defining companions. It'd be easy - // enough to just redefine them together but that may not always - // be what people want so I'm waiting until I can do it better. 
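// Illustrative sketch, not part of this patch: quoting path segments that
// collide with Scala keywords, mirroring the backticked helper shown above.
// The keyword set here is only a small sample for illustration.
object BacktickSketch {
  private val keywords = Set("type", "class", "object", "package", "val", "def")
  def backticked(path: String): String =
    path.split('.').map {
      case "_"              => "_"
      case s if keywords(s) => s"`$s`"
      case s                => s
    }.mkString(".")
  // e.g. backticked("foo.type.Bar") == "foo.`type`.Bar"
}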
- exitingTyper { - req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym => - val oldSym = replScope lookup newSym.name.companionName - if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) { - replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.") - replwarn("Companions must be defined together; you may wish to use :paste mode for this.") - } - } - } - exitingTyper { - req.imports foreach (sym => updateReplScope(sym, isDefined = false)) - req.defines foreach (sym => updateReplScope(sym, isDefined = true)) - } - } - - private[nsc] def replwarn(msg: => String) { - if (!settings.nowarnings) - printMessage(msg) - } - - def compileSourcesKeepingRun(sources: SourceFile*) = { - val run = new Run() - assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") - reporter.reset() - run compileSources sources.toList - (!reporter.hasErrors, run) - } - - /** Compile an nsc SourceFile. Returns true if there are - * no compilation errors, or false otherwise. - */ - def compileSources(sources: SourceFile*): Boolean = - compileSourcesKeepingRun(sources: _*)._1 - - /** Compile a string. Returns true if there are no - * compilation errors, or false otherwise. - */ - def compileString(code: String): Boolean = - compileSources(new BatchSourceFile("