diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml index a9f757c3e2413..e2eb0683b6e59 100644 --- a/.github/workflows/master.yml +++ b/.github/workflows/master.yml @@ -4,6 +4,9 @@ on: push: branches: - master + pull_request: + branches: + - master jobs: build: @@ -12,16 +15,46 @@ jobs: strategy: matrix: java: [ '1.8', '11' ] - name: Build Spark with JDK ${{ matrix.java }} + hadoop: [ 'hadoop-2.7', 'hadoop-3.2' ] + exclude: + - java: '11' + hadoop: 'hadoop-2.7' + name: Build Spark with JDK ${{ matrix.java }} and ${{ matrix.hadoop }} steps: - uses: actions/checkout@master - name: Set up JDK ${{ matrix.java }} uses: actions/setup-java@v1 with: - version: ${{ matrix.java }} + java-version: ${{ matrix.java }} - name: Build with Maven run: | - export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN" + export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=1g -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN" export MAVEN_CLI_OPTS="--no-transfer-progress" - ./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pmesos -Pkubernetes -Phive -Phive-thriftserver -Phadoop-3.2 -Phadoop-cloud -Djava.version=${{ matrix.java }} package + ./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pmesos -Pkubernetes -Phive -Phive-thriftserver -P${{ matrix.hadoop }} -Phadoop-cloud -Djava.version=${{ matrix.java }} package + + + lint: + runs-on: ubuntu-latest + name: Linters + steps: + - uses: actions/checkout@master + - uses: actions/setup-java@v1 + with: + java-version: '11' + - uses: actions/setup-python@v1 + with: + python-version: '3.x' + architecture: 'x64' + - name: Scala + run: ./dev/lint-scala + - name: Java + run: ./dev/lint-java + - name: Python + run: | + pip install flake8 sphinx numpy + ./dev/lint-python + - name: License + run: ./dev/check-license + - name: Dependencies + run: ./dev/test-dependencies.sh diff --git a/LICENSE-binary b/LICENSE-binary index ba20eea118687..d2a189a3fca11 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -218,6 +218,7 @@ javax.jdo:jdo-api joda-time:joda-time net.sf.opencsv:opencsv org.apache.derby:derby +org.ehcache:ehcache org.objenesis:objenesis org.roaringbitmap:RoaringBitmap org.scalanlp:breeze-macros_2.12 @@ -259,6 +260,7 @@ net.sf.supercsv:super-csv org.apache.arrow:arrow-format org.apache.arrow:arrow-memory org.apache.arrow:arrow-vector +org.apache.commons:commons-configuration2 org.apache.commons:commons-crypto org.apache.commons:commons-lang3 org.apache.hadoop:hadoop-annotations @@ -266,6 +268,7 @@ org.apache.hadoop:hadoop-auth org.apache.hadoop:hadoop-client org.apache.hadoop:hadoop-common org.apache.hadoop:hadoop-hdfs +org.apache.hadoop:hadoop-hdfs-client org.apache.hadoop:hadoop-mapreduce-client-app org.apache.hadoop:hadoop-mapreduce-client-common org.apache.hadoop:hadoop-mapreduce-client-core @@ -278,6 +281,21 @@ org.apache.hadoop:hadoop-yarn-server-common org.apache.hadoop:hadoop-yarn-server-web-proxy org.apache.httpcomponents:httpclient org.apache.httpcomponents:httpcore +org.apache.kerby:kerb-admin +org.apache.kerby:kerb-client +org.apache.kerby:kerb-common +org.apache.kerby:kerb-core +org.apache.kerby:kerb-crypto +org.apache.kerby:kerb-identity +org.apache.kerby:kerb-server +org.apache.kerby:kerb-simplekdc +org.apache.kerby:kerb-util +org.apache.kerby:kerby-asn1 +org.apache.kerby:kerby-config +org.apache.kerby:kerby-pkix +org.apache.kerby:kerby-util +org.apache.kerby:kerby-xdr +org.apache.kerby:token-provider org.apache.orc:orc-core org.apache.orc:orc-mapreduce org.mortbay.jetty:jetty @@ -292,14 +310,19 @@ 
com.fasterxml.jackson.core:jackson-annotations com.fasterxml.jackson.core:jackson-core com.fasterxml.jackson.core:jackson-databind com.fasterxml.jackson.dataformat:jackson-dataformat-yaml +com.fasterxml.jackson.jaxrs:jackson-jaxrs-base +com.fasterxml.jackson.jaxrs:jackson-jaxrs-json-provider com.fasterxml.jackson.module:jackson-module-jaxb-annotations com.fasterxml.jackson.module:jackson-module-paranamer com.fasterxml.jackson.module:jackson-module-scala_2.12 +com.fasterxml.woodstox:woodstox-core com.github.mifmif:generex +com.github.stephenc.jcip:jcip-annotations com.google.code.findbugs:jsr305 com.google.code.gson:gson com.google.inject:guice com.google.inject.extensions:guice-servlet +com.nimbusds:nimbus-jose-jwt com.twitter:parquet-hadoop-bundle commons-cli:commons-cli commons-dbcp:commons-dbcp @@ -313,6 +336,8 @@ javax.inject:javax.inject javax.validation:validation-api log4j:apache-log4j-extras log4j:log4j +net.minidev:accessors-smart +net.minidev:json-smart net.sf.jpam:jpam org.apache.avro:avro org.apache.avro:avro-ipc @@ -328,6 +353,7 @@ org.apache.directory.server:apacheds-i18n org.apache.directory.server:apacheds-kerberos-codec org.apache.htrace:htrace-core org.apache.ivy:ivy +org.apache.geronimo.specs:geronimo-jcache_1.0_spec org.apache.mesos:mesos org.apache.parquet:parquet-column org.apache.parquet:parquet-common @@ -369,6 +395,20 @@ org.eclipse.jetty:jetty-webapp org.eclipse.jetty:jetty-xml org.scala-lang.modules:scala-xml_2.12 org.opencypher:okapi-shade +com.github.joshelser:dropwizard-metrics-hadoop-metrics2-reporter +com.zaxxer.HikariCP +org.apache.hive:hive-common +org.apache.hive:hive-llap-common +org.apache.hive:hive-serde +org.apache.hive:hive-service-rpc +org.apache.hive:hive-shims-0.23 +org.apache.hive:hive-shims +org.apache.hive:hive-shims-scheduler +org.apache.hive:hive-storage-api +org.apache.hive:hive-vector-code-gen +org.datanucleus:javax.jdo +com.tdunning:json +org.apache.velocity:velocity core/src/main/java/org/apache/spark/util/collection/TimSort.java core/src/main/resources/org/apache/spark/ui/static/bootstrap* @@ -387,6 +427,7 @@ BSD 2-Clause ------------ com.github.luben:zstd-jni +dnsjava:dnsjava javolution:javolution com.esotericsoftware:kryo-shaded com.esotericsoftware:minlog @@ -394,8 +435,11 @@ com.esotericsoftware:reflectasm com.google.protobuf:protobuf-java org.codehaus.janino:commons-compiler org.codehaus.janino:janino +org.codehaus.woodstox:stax2-api jline:jline org.jodd:jodd-core +com.github.wendykierp:JTransforms +pl.edu.icm:JLargeArrays BSD 3-Clause @@ -408,6 +452,7 @@ org.antlr:stringtemplate org.antlr:antlr4-runtime antlr:antlr com.github.fommil.netlib:core +com.google.re2j:re2j com.thoughtworks.paranamer:paranamer org.scala-lang:scala-compiler org.scala-lang:scala-library @@ -433,8 +478,13 @@ is distributed under the 3-Clause BSD license. 
MIT License ----------- -org.spire-math:spire-macros_2.12 -org.spire-math:spire_2.12 +com.microsoft.sqlserver:mssql-jdbc +org.typelevel:spire_2.12 +org.typelevel:spire-macros_2.12 +org.typelevel:spire-platform_2.12 +org.typelevel:spire-util_2.12 +org.typelevel:algebra_2.12:jar +org.typelevel:cats-kernel_2.12 org.typelevel:machinist_2.12 net.razorvine:pyrolite org.slf4j:jcl-over-slf4j @@ -458,6 +508,7 @@ Common Development and Distribution License (CDDL) 1.0 javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173 +javax.transaction:javax.transaction-api Common Development and Distribution License (CDDL) 1.1 @@ -496,11 +547,6 @@ Eclipse Public License (EPL) 2.0 jakarta.annotation:jakarta-annotation-api https://projects.eclipse.org/projects/ee4j.ca jakarta.ws.rs:jakarta.ws.rs-api https://github.com/eclipse-ee4j/jaxrs-api -Mozilla Public License (MPL) 1.1 --------------------------------- - -com.github.rwl:jtransforms https://sourceforge.net/projects/jtransforms/ - Python Software Foundation License ---------------------------------- diff --git a/NOTICE-binary b/NOTICE-binary index f93e088a9a731..80ddfd10a1874 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -1135,4 +1135,356 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and -limitations under the License. \ No newline at end of file +limitations under the License. + +dropwizard-metrics-hadoop-metrics2-reporter +Copyright 2016 Josh Elser + +Hive Common +Copyright 2019 The Apache Software Foundation + +Hive Llap Common +Copyright 2019 The Apache Software Foundation + +Hive Serde +Copyright 2019 The Apache Software Foundation + +Hive Service RPC +Copyright 2019 The Apache Software Foundation + +Hive Shims 0.23 +Copyright 2019 The Apache Software Foundation + +Hive Shims Common +Copyright 2019 The Apache Software Foundation + +Hive Shims Scheduler +Copyright 2019 The Apache Software Foundation + +Hive Storage API +Copyright 2018 The Apache Software Foundation + +Hive Vector-Code-Gen Utilities +Copyright 2019 The Apache Software Foundation + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2015-2015 DataNucleus + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +Apache Velocity + +Copyright (C) 2000-2007 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Apache Yetus - Audience Annotations +Copyright 2015-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Ehcache V3 +Copyright 2014-2016 Terracotta, Inc. + +The product includes software from the Apache Commons Lang project, +under the Apache License 2.0 (see: org.ehcache.impl.internal.classes.commonslang) + +Apache Geronimo JCache Spec 1.0 +Copyright 2003-2014 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby-kerb Admin +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby-kerb Client +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby-kerb Common +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby-kerb core +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). 
+ + +Kerby-kerb Crypto +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby-kerb Identity +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby-kerb Server +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerb Simple Kdc +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby-kerb Util +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby ASN1 Project +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby Config +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby PKIX Project +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby Util +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Kerby XDR Project +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Token provider +Copyright 2014-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). diff --git a/R/check-cran.sh b/R/check-cran.sh index 22cc9c6b601fc..22c8f423cfd12 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -65,6 +65,10 @@ fi echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" +# Remove this environment variable to allow to check suggested packages once +# Jenkins installs arrow. See SPARK-29339. +export _R_CHECK_FORCE_SUGGESTS_=FALSE + if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] then "$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz" diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index f4780862099d3..95d3e52bef3a9 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -22,7 +22,8 @@ Suggests: rmarkdown, testthat, e1071, - survival + survival, + arrow Collate: 'schema.R' 'generics.R' diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 43ea27b359a9c..f27ef4ee28f16 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -148,19 +148,7 @@ getDefaultSqlSource <- function() { } writeToFileInArrow <- function(fileName, rdf, numPartitions) { - requireNamespace1 <- requireNamespace - - # R API in Arrow is not yet released in CRAN. CRAN requires to add the - # package in requireNamespace at DESCRIPTION. Later, CRAN checks if the package is available - # or not. Therefore, it works around by avoiding direct requireNamespace. - # Currently, as of Arrow 0.12.0, it can be installed by install_github. See ARROW-3204. 
- if (requireNamespace1("arrow", quietly = TRUE)) { - record_batch <- get("record_batch", envir = asNamespace("arrow"), inherits = FALSE) - RecordBatchStreamWriter <- get( - "RecordBatchStreamWriter", envir = asNamespace("arrow"), inherits = FALSE) - FileOutputStream <- get( - "FileOutputStream", envir = asNamespace("arrow"), inherits = FALSE) - + if (requireNamespace("arrow", quietly = TRUE)) { numPartitions <- if (!is.null(numPartitions)) { numToInt(numPartitions) } else { @@ -176,11 +164,11 @@ writeToFileInArrow <- function(fileName, rdf, numPartitions) { stream_writer <- NULL tryCatch({ for (rdf_slice in rdf_slices) { - batch <- record_batch(rdf_slice) + batch <- arrow::record_batch(rdf_slice) if (is.null(stream_writer)) { - stream <- FileOutputStream(fileName) + stream <- arrow::FileOutputStream(fileName) schema <- batch$schema - stream_writer <- RecordBatchStreamWriter(stream, schema) + stream_writer <- arrow::RecordBatchStreamWriter(stream, schema) } stream_writer$write_batch(batch) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 51ae2d2954a9a..93ba1307043a3 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -301,7 +301,7 @@ broadcastRDD <- function(sc, object) { #' Set the checkpoint directory #' #' Set the directory under which RDDs are going to be checkpointed. The -#' directory must be a HDFS path if running on a cluster. +#' directory must be an HDFS path if running on a cluster. #' #' @param sc Spark Context to use #' @param dirName Directory path @@ -446,7 +446,7 @@ setLogLevel <- function(level) { #' Set checkpoint directory #' #' Set the directory under which SparkDataFrame are going to be checkpointed. The directory must be -#' a HDFS path if running on a cluster. +#' an HDFS path if running on a cluster. #' #' @rdname setCheckpointDir #' @param directory Directory path to checkpoint to diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index b38d245a0cca7..a6febb1cbd132 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -232,11 +232,7 @@ readMultipleObjectsWithKeys <- function(inputCon) { } readDeserializeInArrow <- function(inputCon) { - # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204. - requireNamespace1 <- requireNamespace - if (requireNamespace1("arrow", quietly = TRUE)) { - RecordBatchStreamReader <- get( - "RecordBatchStreamReader", envir = asNamespace("arrow"), inherits = FALSE) + if (requireNamespace("arrow", quietly = TRUE)) { # Arrow drops `as_tibble` since 0.14.0, see ARROW-5190. useAsTibble <- exists("as_tibble", envir = asNamespace("arrow")) @@ -246,7 +242,7 @@ readDeserializeInArrow <- function(inputCon) { # for now. dataLen <- readInt(inputCon) arrowData <- readBin(inputCon, raw(), as.integer(dataLen), endian = "big") - batches <- RecordBatchStreamReader(arrowData)$batches() + batches <- arrow::RecordBatchStreamReader(arrowData)$batches() if (useAsTibble) { as_tibble <- get("as_tibble", envir = asNamespace("arrow")) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index eecb84572a30b..eec221c2be4bf 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3617,7 +3617,7 @@ setMethod("size", #' @details #' \code{slice}: Returns an array containing all the elements in x from the index start -#' (or starting from the end if start is negative) with the specified length. +#' (array indices start at 1, or from the end if start is negative) with the specified length. 
#' #' @rdname column_collection_functions #' @param start an index indicating the first element occurring in the result. diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index 9a77b07462585..d238ff93ed245 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -82,6 +82,12 @@ setClass("ALSModel", representation(jobj = "jobj")) #' statsS <- summary(modelS) #' } #' @note spark.als since 2.1.0 +#' @note the input rating dataframe to the ALS implementation should be deterministic. +#' Nondeterministic data can cause failure during fitting ALS model. For example, +#' an order-sensitive operation like sampling after a repartition makes dataframe output +#' nondeterministic, like \code{sample(repartition(df, 2L), FALSE, 0.5, 1618L)}. +#' Checkpointing sampled dataframe or adding a sort before sampling can help make the +#' dataframe deterministic. setMethod("spark.als", signature(data = "SparkDataFrame"), function(data, ratingCol = "rating", userCol = "user", itemCol = "item", rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE, diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 0d6f32c8f7e1f..cb3c1c59d12ed 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -222,15 +222,11 @@ writeArgs <- function(con, args) { } writeSerializeInArrow <- function(conn, df) { - # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204. - requireNamespace1 <- requireNamespace - if (requireNamespace1("arrow", quietly = TRUE)) { - write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE) - + if (requireNamespace("arrow", quietly = TRUE)) { # There looks no way to send each batch in streaming format via socket # connection. See ARROW-4512. # So, it writes the whole Arrow streaming-formatted binary at once for now. 
- writeRaw(conn, write_arrow(df, raw())) + writeRaw(conn, arrow::write_arrow(df, raw())) } else { stop("'arrow' package should be installed.") } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 31b986c326d0c..cdb59093781fb 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -266,11 +266,12 @@ sparkR.sparkContext <- function( #' df <- read.json(path) #' #' sparkR.session("local[2]", "SparkR", "/home/spark") -#' sparkR.session("yarn-client", "SparkR", "/home/spark", -#' list(spark.executor.memory="4g"), +#' sparkR.session("yarn", "SparkR", "/home/spark", +#' list(spark.executor.memory="4g", spark.submit.deployMode="client"), #' c("one.jar", "two.jar", "three.jar"), #' c("com.databricks:spark-avro_2.12:2.0.1")) -#' sparkR.session(spark.master = "yarn-client", spark.executor.memory = "4g") +#' sparkR.session(spark.master = "yarn", spark.submit.deployMode = "client", +# spark.executor.memory = "4g") #'} #' @note sparkR.session since 2.0.0 sparkR.session <- function( diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 80dc4ee634512..dfe69b7f4f1fb 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -50,7 +50,7 @@ compute <- function(mode, partition, serializer, deserializer, key, } else { # Check to see if inputData is a valid data.frame stopifnot(deserializer == "byte" || deserializer == "arrow") - stopifnot(class(inputData) == "data.frame") + stopifnot(is.data.frame(inputData)) } if (mode == 2) { diff --git a/R/pkg/tests/fulltests/test_sparkR.R b/R/pkg/tests/fulltests/test_sparkR.R index f73fc6baeccef..4232f5ec430f6 100644 --- a/R/pkg/tests/fulltests/test_sparkR.R +++ b/R/pkg/tests/fulltests/test_sparkR.R @@ -36,8 +36,8 @@ test_that("sparkCheckInstall", { # "yarn-client, mesos-client" mode, SPARK_HOME was not set sparkHome <- "" - master <- "yarn-client" - deployMode <- "" + master <- "yarn" + deployMode <- "client" expect_error(sparkCheckInstall(sparkHome, master, deployMode)) sparkHome <- "" master <- "" diff --git a/R/pkg/tests/fulltests/test_sparkSQL_arrow.R b/R/pkg/tests/fulltests/test_sparkSQL_arrow.R index 825c7423e1579..97972753a78fa 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL_arrow.R +++ b/R/pkg/tests/fulltests/test_sparkSQL_arrow.R @@ -101,7 +101,7 @@ test_that("dapply() Arrow optimization", { tryCatch({ ret <- dapply(df, function(rdf) { - stopifnot(class(rdf) == "data.frame") + stopifnot(is.data.frame(rdf)) rdf }, schema(df)) @@ -115,7 +115,7 @@ test_that("dapply() Arrow optimization", { tryCatch({ ret <- dapply(df, function(rdf) { - stopifnot(class(rdf) == "data.frame") + stopifnot(is.data.frame(rdf)) # mtcars' hp is more then 50. stopifnot(all(rdf$hp > 50)) rdf @@ -199,7 +199,7 @@ test_that("gapply() Arrow optimization", { if (length(key) > 0) { stopifnot(is.numeric(key[[1]])) } - stopifnot(class(grouped) == "data.frame") + stopifnot(is.data.frame(grouped)) grouped }, schema(df)) @@ -217,7 +217,7 @@ test_that("gapply() Arrow optimization", { if (length(key) > 0) { stopifnot(is.numeric(key[[1]])) } - stopifnot(class(grouped) == "data.frame") + stopifnot(is.data.frame(grouped)) stopifnot(length(colnames(grouped)) == 11) # mtcars' hp is more then 50. 
stopifnot(all(grouped$hp > 50)) diff --git a/appveyor.yml b/appveyor.yml index a61436c5d2e68..00c688ba18eb6 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -42,13 +42,13 @@ install: # Install maven and dependencies - ps: .\dev\appveyor-install-dependencies.ps1 # Required package for R unit tests - - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'e1071', 'survival'), repos='https://cloud.r-project.org/')" + - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'e1071', 'survival', 'arrow'), repos='https://cloud.r-project.org/')" # Here, we use the fixed version of testthat. For more details, please see SPARK-22817. # As of devtools 2.1.0, it requires testthat higher then 2.1.1 as a dependency. SparkR test requires testthat 1.0.2. # Therefore, we don't use devtools but installs it directly from the archive including its dependencies. - cmd: R -e "install.packages(c('crayon', 'praise', 'R6'), repos='https://cloud.r-project.org/')" - cmd: R -e "install.packages('https://cloud.r-project.org/src/contrib/Archive/testthat/testthat_1.0.2.tar.gz', repos=NULL, type='source')" - - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')" + - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival'); packageVersion('arrow')" build_script: # '-Djna.nosys=true' is required to avoid kernel32.dll load failure. diff --git a/build/mvn b/build/mvn index f68377b3ddc71..3628be9880253 100755 --- a/build/mvn +++ b/build/mvn @@ -22,7 +22,7 @@ _DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" # Preserve the calling directory _CALLING_DIR="$(pwd)" # Options used during compilation -_COMPILE_JVM_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m" +_COMPILE_JVM_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=1g" # Installs any application tarball given a URL, the expected tarball name, # and, optionally, a checkable binary path to determine if the binary has diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java index 6af45aec3c7b2..b33c53871c32f 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java @@ -252,7 +252,7 @@ private static Predicate getPredicate( return (value) -> set.contains(indexValueForEntity(getter, value)); } else { - HashSet set = new HashSet<>(values.size()); + HashSet> set = new HashSet<>(values.size()); for (Object key : values) { set.add(asKey(key)); } diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java index b8c5fab8709ed..d2a26982d8703 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java @@ -124,7 +124,7 @@ interface Accessor { Object get(Object instance) throws ReflectiveOperationException; - Class getType(); + Class getType(); } private class FieldAccessor implements Accessor { @@ -141,7 +141,7 @@ public Object get(Object instance) throws ReflectiveOperationException { } @Override - public Class getType() { + public Class getType() { return field.getType(); } } @@ -160,7 +160,7 @@ public Object get(Object instance) throws ReflectiveOperationException { } @Override - 
public Class getType() { + public Class getType() { return method.getReturnType(); } } diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index c107af9ceb415..2ee17800c10e4 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -35,6 +35,12 @@ + + + org.scala-lang + scala-library + + io.netty @@ -87,13 +93,6 @@ - - - org.scala-lang - scala-library - ${scala.version} - test - log4j log4j diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 53835d8304866..c9ef9f918ffd1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -293,9 +293,8 @@ public void close() { } connectionPool.clear(); - if (workerGroup != null) { + if (workerGroup != null && !workerGroup.isShuttingDown()) { workerGroup.shutdownGracefully(); - workerGroup = null; } } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java index 736059fdd1f57..490915f6de4b3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -112,4 +112,27 @@ public static int[] decode(ByteBuf buf) { return ints; } } + + /** Long integer arrays are encoded with their length followed by long integers. */ + public static class LongArrays { + public static int encodedLength(long[] longs) { + return 4 + 8 * longs.length; + } + + public static void encode(ByteBuf buf, long[] longs) { + buf.writeInt(longs.length); + for (long i : longs) { + buf.writeLong(i); + } + } + + public static long[] decode(ByteBuf buf) { + int numLongs = buf.readInt(); + long[] longs = new long[numLongs]; + for (int i = 0; i < longs.length; i ++) { + longs[i] = buf.readLong(); + } + return longs; + } + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 2aec4a33bbe43..9b76981c31c57 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -217,4 +217,11 @@ public Iterable> getAll() { assertFalse(c1.isActive()); } } + + @Test(expected = IOException.class) + public void closeFactoryBeforeCreateClient() throws IOException, InterruptedException { + TransportClientFactory factory = context.createClientFactory(); + factory.close(); + factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 037e5cf7e5222..2d7a72315cf23 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -106,7 +106,7 @@ protected void handleMessage( numBlockIds += ids.length; } streamId = 
streamManager.registerStream(client.getClientId(), - new ManagedBufferIterator(msg, numBlockIds), client.getChannel()); + new ShuffleManagedBufferIterator(msg), client.getChannel()); } else { // For the compatibility with the old version, still keep the support for OpenBlocks. OpenBlocks msg = (OpenBlocks) msgObj; @@ -299,21 +299,6 @@ private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) { return mapIdAndReduceIds; } - ManagedBufferIterator(FetchShuffleBlocks msg, int numBlockIds) { - final int[] mapIdAndReduceIds = new int[2 * numBlockIds]; - int idx = 0; - for (int i = 0; i < msg.mapIds.length; i++) { - for (int reduceId : msg.reduceIds[i]) { - mapIdAndReduceIds[idx++] = msg.mapIds[i]; - mapIdAndReduceIds[idx++] = reduceId; - } - } - assert(idx == 2 * numBlockIds); - size = mapIdAndReduceIds.length; - blockDataForIndexFn = index -> blockManager.getBlockData(msg.appId, msg.execId, - msg.shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); - } - @Override public boolean hasNext() { return index < size; @@ -328,6 +313,49 @@ public ManagedBuffer next() { } } + private class ShuffleManagedBufferIterator implements Iterator { + + private int mapIdx = 0; + private int reduceIdx = 0; + + private final String appId; + private final String execId; + private final int shuffleId; + private final long[] mapIds; + private final int[][] reduceIds; + + ShuffleManagedBufferIterator(FetchShuffleBlocks msg) { + appId = msg.appId; + execId = msg.execId; + shuffleId = msg.shuffleId; + mapIds = msg.mapIds; + reduceIds = msg.reduceIds; + } + + @Override + public boolean hasNext() { + // mapIds.length must equal to reduceIds.length, and the passed in FetchShuffleBlocks + // must have non-empty mapIds and reduceIds, see the checking logic in + // OneForOneBlockFetcher. + assert(mapIds.length != 0 && mapIds.length == reduceIds.length); + return mapIdx < mapIds.length && reduceIdx < reduceIds[mapIdx].length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData( + appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]); + if (reduceIdx < reduceIds[mapIdx].length - 1) { + reduceIdx += 1; + } else { + reduceIdx = 0; + mapIdx += 1; + } + metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); + return block; + } + } + @Override public void channelActive(TransportClient client) { metrics.activeConnections.inc(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 50f16fc700f12..8b0d1e145a813 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -172,7 +172,7 @@ public ManagedBuffer getBlockData( String appId, String execId, int shuffleId, - int mapId, + long mapId, int reduceId) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { @@ -296,7 +296,7 @@ private void deleteNonShuffleServiceServedFiles(String[] dirs) { * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. 
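The ShuffleManagedBufferIterator introduced above serves the requested blocks map-by-map, exhausting every reduce id of one map id before advancing to the next. A minimal standalone sketch of that visit order, using made-up ids and plain loops rather than any Spark API, might look like:

    // Illustrative only: mirrors the mapIdx/reduceIdx bookkeeping of ShuffleManagedBufferIterator.
    public class ShuffleIterationOrderSketch {
      public static void main(String[] args) {
        long[] mapIds = {7L, 9L};            // hypothetical map ids from a FetchShuffleBlocks request
        int[][] reduceIds = {{0, 1}, {3}};   // reduce ids requested for each map id
        for (int mapIdx = 0; mapIdx < mapIds.length; mapIdx++) {
          for (int reduceIdx = 0; reduceIdx < reduceIds[mapIdx].length; reduceIdx++) {
            // One ManagedBuffer would be returned per step: shuffle_<shuffleId>_<mapId>_<reduceId>
            System.out.println("shuffle_0_" + mapIds[mapIdx] + "_" + reduceIds[mapIdx][reduceIdx]);
          }
        }
      }
    }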
*/ private ManagedBuffer getSortBasedShuffleBlockData( - ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + ExecutorShuffleInfo executor, int shuffleId, long mapId, int reduceId) { File indexFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, "shuffle_" + shuffleId + "_" + mapId + "_0.index"); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index cc11e92067375..52854c86be3e6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -24,6 +24,8 @@ import java.util.HashMap; import com.google.common.primitives.Ints; +import com.google.common.primitives.Longs; +import org.apache.commons.lang3.tuple.ImmutableTriple; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -111,21 +113,21 @@ private boolean isShuffleBlocks(String[] blockIds) { */ private FetchShuffleBlocks createFetchShuffleBlocksMsg( String appId, String execId, String[] blockIds) { - int shuffleId = splitBlockId(blockIds[0])[0]; - HashMap> mapIdToReduceIds = new HashMap<>(); + int shuffleId = splitBlockId(blockIds[0]).left; + HashMap> mapIdToReduceIds = new HashMap<>(); for (String blockId : blockIds) { - int[] blockIdParts = splitBlockId(blockId); - if (blockIdParts[0] != shuffleId) { + ImmutableTriple blockIdParts = splitBlockId(blockId); + if (blockIdParts.left != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId); } - int mapId = blockIdParts[1]; + long mapId = blockIdParts.middle; if (!mapIdToReduceIds.containsKey(mapId)) { mapIdToReduceIds.put(mapId, new ArrayList<>()); } - mapIdToReduceIds.get(mapId).add(blockIdParts[2]); + mapIdToReduceIds.get(mapId).add(blockIdParts.right); } - int[] mapIds = Ints.toArray(mapIdToReduceIds.keySet()); + long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet()); int[][] reduceIdArr = new int[mapIds.length][]; for (int i = 0; i < mapIds.length; i++) { reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i])); @@ -134,17 +136,16 @@ private FetchShuffleBlocks createFetchShuffleBlocksMsg( } /** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */ - private int[] splitBlockId(String blockId) { + private ImmutableTriple splitBlockId(String blockId) { String[] blockIdParts = blockId.split("_"); if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { throw new IllegalArgumentException( "Unexpected shuffle block id format: " + blockId); } - return new int[] { - Integer.parseInt(blockIdParts[1]), - Integer.parseInt(blockIdParts[2]), - Integer.parseInt(blockIdParts[3]) - }; + return new ImmutableTriple<>( + Integer.parseInt(blockIdParts[1]), + Long.parseLong(blockIdParts[2]), + Integer.parseInt(blockIdParts[3])); } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. 
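Because map ids can now exceed the int range, the reworked client-side splitBlockId above parses the middle component of a "shuffle_<shuffleId>_<mapId>_<reduceId>" block id as a long and returns a commons-lang3 ImmutableTriple. A hedged, standalone illustration of that parsing (the class name and sample block id are made up for the sketch):

    import org.apache.commons.lang3.tuple.ImmutableTriple;

    public class SplitBlockIdSketch {
      // Mirrors the new parsing: shuffleId is an int, mapId a long, reduceId an int.
      static ImmutableTriple<Integer, Long, Integer> split(String blockId) {
        String[] parts = blockId.split("_");
        if (parts.length != 4 || !parts[0].equals("shuffle")) {
          throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId);
        }
        return new ImmutableTriple<>(
          Integer.parseInt(parts[1]), Long.parseLong(parts[2]), Integer.parseInt(parts[3]));
      }

      public static void main(String[] args) {
        // A map id larger than Integer.MAX_VALUE would not fit the old int[]-based parsing.
        ImmutableTriple<Integer, Long, Integer> ids = split("shuffle_3_5000000000_7");
        System.out.println(ids.left + " / " + ids.middle + " / " + ids.right);
      }
    }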
*/ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java index 466eeb3e048a8..faa960d414bcc 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java @@ -34,14 +34,14 @@ public class FetchShuffleBlocks extends BlockTransferMessage { public final int shuffleId; // The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds, // it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id. - public final int[] mapIds; + public final long[] mapIds; public final int[][] reduceIds; public FetchShuffleBlocks( String appId, String execId, int shuffleId, - int[] mapIds, + long[] mapIds, int[][] reduceIds) { this.appId = appId; this.execId = execId; @@ -98,7 +98,7 @@ public int encodedLength() { return Encoders.Strings.encodedLength(appId) + Encoders.Strings.encodedLength(execId) + 4 /* encoded length of shuffleId */ - + Encoders.IntArrays.encodedLength(mapIds) + + Encoders.LongArrays.encodedLength(mapIds) + 4 /* encoded length of reduceIds.size() */ + encodedLengthOfReduceIds; } @@ -108,7 +108,7 @@ public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); Encoders.Strings.encode(buf, execId); buf.writeInt(shuffleId); - Encoders.IntArrays.encode(buf, mapIds); + Encoders.LongArrays.encode(buf, mapIds); buf.writeInt(reduceIds.length); for (int[] ids: reduceIds) { Encoders.IntArrays.encode(buf, ids); @@ -119,7 +119,7 @@ public static FetchShuffleBlocks decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); String execId = Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); - int[] mapIds = Encoders.IntArrays.decode(buf); + long[] mapIds = Encoders.LongArrays.decode(buf); int reduceIdsSize = buf.readInt(); int[][] reduceIds = new int[reduceIdsSize][]; for (int i = 0; i < reduceIdsSize; i++) { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index 649c471dc1679..ba40f4a45ac8f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -29,7 +29,7 @@ public class BlockTransferMessagesSuite { public void serializeOpenShuffleBlocks() { checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); checkSerializeDeserialize(new FetchShuffleBlocks( - "app-1", "exec-2", 0, new int[] {0, 1}, + "app-1", "exec-2", 0, new long[] {0, 1}, new int[][] {{ 0, 1 }, { 0, 1, 2 }})); checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 9c623a70424b6..6a5d04b6f417b 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -101,7 +101,7 @@ public void testFetchShuffleBlocks() { when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]); FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks( - "app0", "exec1", 0, new int[] { 0 }, new int[][] {{ 0, 1 }}); + "app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}); checkOpenBlocksReceive(fetchShuffleBlocks, blockMarkers); verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 66633cc7a3595..26a11672b8068 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -64,7 +64,7 @@ public void testFetchOne() { BlockFetchingListener listener = fetchBlocks( blocks, blockIds, - new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0 }}), + new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0 }}), conf); verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); @@ -100,7 +100,7 @@ public void testFetchThreeShuffleBlocks() { BlockFetchingListener listener = fetchBlocks( blocks, blockIds, - new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0, 1, 2 }}), + new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 1, 2 }}), conf); for (int i = 0; i < 3; i ++) { diff --git a/common/tags/src/test/java/org/apache/spark/tags/ExtendedSQLTest.java b/common/tags/src/test/java/org/apache/spark/tags/ExtendedSQLTest.java new file mode 100644 index 0000000000000..1c0fff1b4045d --- /dev/null +++ b/common/tags/src/test/java/org/apache/spark/tags/ExtendedSQLTest.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
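On the wire, the long map ids in FetchShuffleBlocks are handled by the Encoders.LongArrays helper added earlier in this patch (length prefix followed by the longs). A small round-trip sketch against a Netty ByteBuf, assuming a spark-network-common build that includes this patch is on the classpath (class name and sample ids are illustrative):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import org.apache.spark.network.protocol.Encoders;

    public class LongArraysRoundTrip {
      public static void main(String[] args) {
        long[] mapIds = {0L, 1L, 5000000000L};  // map ids may now exceed the int range
        // encodedLength = 4 bytes for the count + 8 bytes per long
        ByteBuf buf = Unpooled.buffer(Encoders.LongArrays.encodedLength(mapIds));
        Encoders.LongArrays.encode(buf, mapIds);     // writes the length, then each long
        long[] decoded = Encoders.LongArrays.decode(buf);
        assert decoded.length == 3 && decoded[2] == 5000000000L;
      }
    }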
+ */ + +package org.apache.spark.tags; + +import org.scalatest.TagAnnotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedSQLTest { } diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index fdb81a06d41c9..72aa682bb95bc 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.unsafe.types import org.apache.commons.text.similarity.LevenshteinDistance import org.scalacheck.{Arbitrary, Gen} -import org.scalatest.prop.GeneratorDrivenPropertyChecks +import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks // scalastyle:off import org.scalatest.{FunSuite, Matchers} @@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8} /** * This TestSuite utilize ScalaCheck to generate randomized inputs for UTF8String testing. */ -class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenPropertyChecks with Matchers { +class UTF8StringPropertyCheckSuite extends FunSuite with ScalaCheckDrivenPropertyChecks with Matchers { // scalastyle:on test("toString") { diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index da0b06d295252..f52d33fd64223 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -113,6 +113,15 @@ # /metrics/applications/json # App information # /metrics/master/json # Master information +# org.apache.spark.metrics.sink.PrometheusServlet +# Name: Default: Description: +# path VARIES* Path prefix from the web server root +# +# * Default path is /metrics/prometheus for all instances except the master. 
The +# master has two paths: +# /metrics/applications/prometheus # App information +# /metrics/master/prometheus # Master information + # org.apache.spark.metrics.sink.GraphiteSink # Name: Default: Description: # host NONE Hostname of the Graphite server, must be set @@ -192,4 +201,10 @@ #driver.source.jvm.class=org.apache.spark.metrics.source.JvmSource -#executor.source.jvm.class=org.apache.spark.metrics.source.JvmSource \ No newline at end of file +#executor.source.jvm.class=org.apache.spark.metrics.source.JvmSource + +# Example configuration for PrometheusServlet +#*.sink.prometheusServlet.class=org.apache.spark.metrics.sink.PrometheusServlet +#*.sink.prometheusServlet.path=/metrics/prometheus +#master.sink.prometheusServlet.path=/metrics/master/prometheus +#applications.sink.prometheusServlet.path=/metrics/applications/prometheus diff --git a/core/benchmarks/CoalescedRDDBenchmark-jdk11-results.txt b/core/benchmarks/CoalescedRDDBenchmark-jdk11-results.txt new file mode 100644 index 0000000000000..e944111ff9e93 --- /dev/null +++ b/core/benchmarks/CoalescedRDDBenchmark-jdk11-results.txt @@ -0,0 +1,40 @@ +================================================================================================ +Coalesced RDD , large scale +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Coalesced RDD: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Coalesce Num Partitions: 100 Num Hosts: 1 344 360 14 0.3 3441.4 1.0X +Coalesce Num Partitions: 100 Num Hosts: 5 283 301 22 0.4 2825.1 1.2X +Coalesce Num Partitions: 100 Num Hosts: 10 270 271 2 0.4 2700.5 1.3X +Coalesce Num Partitions: 100 Num Hosts: 20 272 273 1 0.4 2721.1 1.3X +Coalesce Num Partitions: 100 Num Hosts: 40 271 272 1 0.4 2710.0 1.3X +Coalesce Num Partitions: 100 Num Hosts: 80 266 267 2 0.4 2656.3 1.3X +Coalesce Num Partitions: 500 Num Hosts: 1 609 619 15 0.2 6089.0 0.6X +Coalesce Num Partitions: 500 Num Hosts: 5 338 343 6 0.3 3383.0 1.0X +Coalesce Num Partitions: 500 Num Hosts: 10 303 306 3 0.3 3029.4 1.1X +Coalesce Num Partitions: 500 Num Hosts: 20 286 288 2 0.4 2855.9 1.2X +Coalesce Num Partitions: 500 Num Hosts: 40 279 282 4 0.4 2793.3 1.2X +Coalesce Num Partitions: 500 Num Hosts: 80 273 275 3 0.4 2725.9 1.3X +Coalesce Num Partitions: 1000 Num Hosts: 1 951 955 4 0.1 9514.1 0.4X +Coalesce Num Partitions: 1000 Num Hosts: 5 421 429 8 0.2 4211.3 0.8X +Coalesce Num Partitions: 1000 Num Hosts: 10 347 352 4 0.3 3473.5 1.0X +Coalesce Num Partitions: 1000 Num Hosts: 20 309 312 5 0.3 3087.5 1.1X +Coalesce Num Partitions: 1000 Num Hosts: 40 290 294 6 0.3 2896.4 1.2X +Coalesce Num Partitions: 1000 Num Hosts: 80 281 286 5 0.4 2811.3 1.2X +Coalesce Num Partitions: 5000 Num Hosts: 1 3928 3950 27 0.0 39278.0 0.1X +Coalesce Num Partitions: 5000 Num Hosts: 5 1373 1389 27 0.1 13725.2 0.3X +Coalesce Num Partitions: 5000 Num Hosts: 10 812 827 13 0.1 8123.3 0.4X +Coalesce Num Partitions: 5000 Num Hosts: 20 530 540 9 0.2 5299.1 0.6X +Coalesce Num Partitions: 5000 Num Hosts: 40 421 425 5 0.2 4210.5 0.8X +Coalesce Num Partitions: 5000 Num Hosts: 80 335 344 12 0.3 3353.7 1.0X +Coalesce Num Partitions: 10000 Num Hosts: 1 7116 7120 4 0.0 71159.0 0.0X +Coalesce Num Partitions: 10000 Num Hosts: 5 2539 2598 51 0.0 25390.1 0.1X +Coalesce Num Partitions: 10000 Num 
Hosts: 10 1393 1432 34 0.1 13928.1 0.2X +Coalesce Num Partitions: 10000 Num Hosts: 20 833 1009 303 0.1 8329.2 0.4X +Coalesce Num Partitions: 10000 Num Hosts: 40 562 563 3 0.2 5615.2 0.6X +Coalesce Num Partitions: 10000 Num Hosts: 80 420 426 7 0.2 4204.0 0.8X + + diff --git a/core/benchmarks/CoalescedRDDBenchmark-results.txt b/core/benchmarks/CoalescedRDDBenchmark-results.txt index dd63b0adea4f2..f1b867951a074 100644 --- a/core/benchmarks/CoalescedRDDBenchmark-results.txt +++ b/core/benchmarks/CoalescedRDDBenchmark-results.txt @@ -2,39 +2,39 @@ Coalesced RDD , large scale ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_201-b09 on Windows 10 10.0 -Intel64 Family 6 Model 63 Stepping 2, GenuineIntel +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Coalesced RDD: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Coalesce Num Partitions: 100 Num Hosts: 1 346 364 24 0.3 3458.9 1.0X -Coalesce Num Partitions: 100 Num Hosts: 5 258 264 6 0.4 2579.0 1.3X -Coalesce Num Partitions: 100 Num Hosts: 10 242 249 7 0.4 2415.2 1.4X -Coalesce Num Partitions: 100 Num Hosts: 20 237 242 7 0.4 2371.7 1.5X -Coalesce Num Partitions: 100 Num Hosts: 40 230 231 1 0.4 2299.8 1.5X -Coalesce Num Partitions: 100 Num Hosts: 80 222 233 14 0.4 2223.0 1.6X -Coalesce Num Partitions: 500 Num Hosts: 1 659 665 5 0.2 6590.4 0.5X -Coalesce Num Partitions: 500 Num Hosts: 5 340 381 47 0.3 3395.2 1.0X -Coalesce Num Partitions: 500 Num Hosts: 10 279 307 47 0.4 2788.3 1.2X -Coalesce Num Partitions: 500 Num Hosts: 20 259 261 2 0.4 2591.9 1.3X -Coalesce Num Partitions: 500 Num Hosts: 40 241 250 15 0.4 2406.5 1.4X -Coalesce Num Partitions: 500 Num Hosts: 80 235 237 3 0.4 2349.9 1.5X -Coalesce Num Partitions: 1000 Num Hosts: 1 1050 1053 4 0.1 10503.2 0.3X -Coalesce Num Partitions: 1000 Num Hosts: 5 405 407 2 0.2 4049.5 0.9X -Coalesce Num Partitions: 1000 Num Hosts: 10 320 322 2 0.3 3202.7 1.1X -Coalesce Num Partitions: 1000 Num Hosts: 20 276 277 0 0.4 2762.3 1.3X -Coalesce Num Partitions: 1000 Num Hosts: 40 257 260 5 0.4 2571.2 1.3X -Coalesce Num Partitions: 1000 Num Hosts: 80 245 252 13 0.4 2448.9 1.4X -Coalesce Num Partitions: 5000 Num Hosts: 1 3099 3145 55 0.0 30988.6 0.1X -Coalesce Num Partitions: 5000 Num Hosts: 5 1037 1050 20 0.1 10374.4 0.3X -Coalesce Num Partitions: 5000 Num Hosts: 10 626 633 8 0.2 6261.8 0.6X -Coalesce Num Partitions: 5000 Num Hosts: 20 426 431 5 0.2 4258.6 0.8X -Coalesce Num Partitions: 5000 Num Hosts: 40 328 341 22 0.3 3275.4 1.1X -Coalesce Num Partitions: 5000 Num Hosts: 80 272 275 4 0.4 2721.4 1.3X -Coalesce Num Partitions: 10000 Num Hosts: 1 5516 5526 9 0.0 55156.8 0.1X -Coalesce Num Partitions: 10000 Num Hosts: 5 1956 1992 48 0.1 19560.9 0.2X -Coalesce Num Partitions: 10000 Num Hosts: 10 1045 1057 18 0.1 10447.4 0.3X -Coalesce Num Partitions: 10000 Num Hosts: 20 637 658 24 0.2 6373.2 0.5X -Coalesce Num Partitions: 10000 Num Hosts: 40 431 448 15 0.2 4312.9 0.8X -Coalesce Num Partitions: 10000 Num Hosts: 80 326 328 2 0.3 3263.4 1.1X +Coalesce Num Partitions: 100 Num Hosts: 1 395 401 9 0.3 3952.3 1.0X +Coalesce Num Partitions: 100 Num Hosts: 5 296 344 42 0.3 2963.2 1.3X +Coalesce Num Partitions: 100 Num Hosts: 10 294 308 15 0.3 2941.7 1.3X +Coalesce Num Partitions: 100 Num Hosts: 20 316 328 13 0.3 3155.2 1.3X +Coalesce 
Num Partitions: 100 Num Hosts: 40 294 316 36 0.3 2940.3 1.3X +Coalesce Num Partitions: 100 Num Hosts: 80 292 324 30 0.3 2922.2 1.4X +Coalesce Num Partitions: 500 Num Hosts: 1 629 687 61 0.2 6292.4 0.6X +Coalesce Num Partitions: 500 Num Hosts: 5 354 378 42 0.3 3541.7 1.1X +Coalesce Num Partitions: 500 Num Hosts: 10 318 338 29 0.3 3179.8 1.2X +Coalesce Num Partitions: 500 Num Hosts: 20 306 317 11 0.3 3059.2 1.3X +Coalesce Num Partitions: 500 Num Hosts: 40 294 311 28 0.3 2941.6 1.3X +Coalesce Num Partitions: 500 Num Hosts: 80 288 309 34 0.3 2883.9 1.4X +Coalesce Num Partitions: 1000 Num Hosts: 1 956 978 20 0.1 9562.2 0.4X +Coalesce Num Partitions: 1000 Num Hosts: 5 431 452 36 0.2 4306.2 0.9X +Coalesce Num Partitions: 1000 Num Hosts: 10 358 379 23 0.3 3581.1 1.1X +Coalesce Num Partitions: 1000 Num Hosts: 20 324 347 20 0.3 3236.7 1.2X +Coalesce Num Partitions: 1000 Num Hosts: 40 312 333 20 0.3 3116.8 1.3X +Coalesce Num Partitions: 1000 Num Hosts: 80 307 342 32 0.3 3068.4 1.3X +Coalesce Num Partitions: 5000 Num Hosts: 1 3895 3906 12 0.0 38946.8 0.1X +Coalesce Num Partitions: 5000 Num Hosts: 5 1388 1401 19 0.1 13881.7 0.3X +Coalesce Num Partitions: 5000 Num Hosts: 10 806 839 57 0.1 8063.7 0.5X +Coalesce Num Partitions: 5000 Num Hosts: 20 546 573 44 0.2 5462.6 0.7X +Coalesce Num Partitions: 5000 Num Hosts: 40 413 418 5 0.2 4134.7 1.0X +Coalesce Num Partitions: 5000 Num Hosts: 80 345 365 23 0.3 3448.1 1.1X +Coalesce Num Partitions: 10000 Num Hosts: 1 6933 6966 55 0.0 69328.8 0.1X +Coalesce Num Partitions: 10000 Num Hosts: 5 2455 2499 69 0.0 24551.7 0.2X +Coalesce Num Partitions: 10000 Num Hosts: 10 1352 1392 34 0.1 13520.2 0.3X +Coalesce Num Partitions: 10000 Num Hosts: 20 815 853 50 0.1 8147.5 0.5X +Coalesce Num Partitions: 10000 Num Hosts: 40 558 581 28 0.2 5578.0 0.7X +Coalesce Num Partitions: 10000 Num Hosts: 80 416 423 5 0.2 4163.3 0.9X diff --git a/core/benchmarks/KryoBenchmark-jdk11-results.txt b/core/benchmarks/KryoBenchmark-jdk11-results.txt new file mode 100644 index 0000000000000..27f0b8f59f47a --- /dev/null +++ b/core/benchmarks/KryoBenchmark-jdk11-results.txt @@ -0,0 +1,28 @@ +================================================================================================ +Benchmark Kryo Unsafe vs safe Serialization +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Benchmark Kryo Unsafe vs safe Serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +basicTypes: Int with unsafe:true 275 288 14 3.6 275.2 1.0X +basicTypes: Long with unsafe:true 331 336 13 3.0 330.9 0.8X +basicTypes: Float with unsafe:true 304 305 1 3.3 304.4 0.9X +basicTypes: Double with unsafe:true 328 332 3 3.0 328.1 0.8X +Array: Int with unsafe:true 4 4 0 252.8 4.0 69.6X +Array: Long with unsafe:true 6 6 0 161.5 6.2 44.5X +Array: Float with unsafe:true 4 4 0 264.6 3.8 72.8X +Array: Double with unsafe:true 6 7 0 160.5 6.2 44.2X +Map of string->Double with unsafe:true 52 52 0 19.3 51.8 5.3X +basicTypes: Int with unsafe:false 344 345 1 2.9 344.3 0.8X +basicTypes: Long with unsafe:false 372 373 1 2.7 372.3 0.7X +basicTypes: Float with unsafe:false 333 334 1 3.0 333.4 0.8X +basicTypes: Double with unsafe:false 344 345 0 2.9 344.3 0.8X +Array: Int with unsafe:false 25 25 0 40.8 24.5 11.2X +Array: Long with 
unsafe:false 37 37 1 27.3 36.7 7.5X +Array: Float with unsafe:false 11 11 0 92.1 10.9 25.4X +Array: Double with unsafe:false 17 18 0 58.3 17.2 16.0X +Map of string->Double with unsafe:false 51 52 1 19.4 51.5 5.3X + + diff --git a/core/benchmarks/KryoBenchmark-results.txt b/core/benchmarks/KryoBenchmark-results.txt index 91e22f3afc14f..49791e6e87e3a 100644 --- a/core/benchmarks/KryoBenchmark-results.txt +++ b/core/benchmarks/KryoBenchmark-results.txt @@ -2,28 +2,27 @@ Benchmark Kryo Unsafe vs safe Serialization ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.13.6 -Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz - -Benchmark Kryo Unsafe vs safe Serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -basicTypes: Int with unsafe:true 138 / 149 7.2 138.0 1.0X -basicTypes: Long with unsafe:true 168 / 173 6.0 167.7 0.8X -basicTypes: Float with unsafe:true 153 / 174 6.5 153.1 0.9X -basicTypes: Double with unsafe:true 161 / 185 6.2 161.1 0.9X -Array: Int with unsafe:true 2 / 3 409.7 2.4 56.5X -Array: Long with unsafe:true 4 / 5 232.5 4.3 32.1X -Array: Float with unsafe:true 3 / 4 367.3 2.7 50.7X -Array: Double with unsafe:true 4 / 5 228.5 4.4 31.5X -Map of string->Double with unsafe:true 38 / 45 26.5 37.8 3.7X -basicTypes: Int with unsafe:false 176 / 187 5.7 175.9 0.8X -basicTypes: Long with unsafe:false 191 / 203 5.2 191.2 0.7X -basicTypes: Float with unsafe:false 166 / 176 6.0 166.2 0.8X -basicTypes: Double with unsafe:false 174 / 190 5.7 174.3 0.8X -Array: Int with unsafe:false 19 / 26 52.9 18.9 7.3X -Array: Long with unsafe:false 27 / 31 37.7 26.5 5.2X -Array: Float with unsafe:false 8 / 10 124.3 8.0 17.2X -Array: Double with unsafe:false 12 / 13 83.6 12.0 11.5X -Map of string->Double with unsafe:false 38 / 42 26.1 38.3 3.6X +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Benchmark Kryo Unsafe vs safe Serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +basicTypes: Int with unsafe:true 269 290 23 3.7 269.0 1.0X +basicTypes: Long with unsafe:true 294 295 1 3.4 293.8 0.9X +basicTypes: Float with unsafe:true 300 301 1 3.3 300.4 0.9X +basicTypes: Double with unsafe:true 304 305 1 3.3 304.0 0.9X +Array: Int with unsafe:true 5 6 1 193.5 5.2 52.0X +Array: Long with unsafe:true 8 9 1 131.2 7.6 35.3X +Array: Float with unsafe:true 6 6 0 163.5 6.1 44.0X +Array: Double with unsafe:true 9 10 0 108.8 9.2 29.3X +Map of string->Double with unsafe:true 54 54 1 18.7 53.6 5.0X +basicTypes: Int with unsafe:false 326 327 1 3.1 326.2 0.8X +basicTypes: Long with unsafe:false 353 354 1 2.8 353.3 0.8X +basicTypes: Float with unsafe:false 325 327 1 3.1 325.1 0.8X +basicTypes: Double with unsafe:false 335 336 1 3.0 335.0 0.8X +Array: Int with unsafe:false 27 28 1 36.7 27.2 9.9X +Array: Long with unsafe:false 40 41 1 25.0 40.0 6.7X +Array: Float with unsafe:false 12 13 1 80.8 12.4 21.7X +Array: Double with unsafe:false 21 21 1 48.6 20.6 13.1X +Map of string->Double with unsafe:false 56 57 1 17.8 56.1 4.8X diff --git a/core/benchmarks/KryoSerializerBenchmark-jdk11-results.txt b/core/benchmarks/KryoSerializerBenchmark-jdk11-results.txt new file mode 100644 index 0000000000000..6b148bde12d36 
--- /dev/null +++ b/core/benchmarks/KryoSerializerBenchmark-jdk11-results.txt @@ -0,0 +1,12 @@ +================================================================================================ +Benchmark KryoPool vs old"pool of 1" implementation +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Benchmark KryoPool vs old"pool of 1" implementation: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +KryoPool:true 6208 8374 NaN 0.0 12416876.6 1.0X +KryoPool:false 9084 11577 724 0.0 18168947.4 0.7X + + diff --git a/core/benchmarks/KryoSerializerBenchmark-results.txt b/core/benchmarks/KryoSerializerBenchmark-results.txt index c3ce336d93241..609f3298cbc00 100644 --- a/core/benchmarks/KryoSerializerBenchmark-results.txt +++ b/core/benchmarks/KryoSerializerBenchmark-results.txt @@ -1,12 +1,12 @@ ================================================================================================ -Benchmark KryoPool vs "pool of 1" +Benchmark KryoPool vs old"pool of 1" implementation ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.14 -Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz -Benchmark KryoPool vs "pool of 1": Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -KryoPool:true 2682 / 3425 0.0 5364627.9 1.0X -KryoPool:false 8176 / 9292 0.0 16351252.2 0.3X +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Benchmark KryoPool vs old"pool of 1" implementation: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +KryoPool:true 6012 7586 NaN 0.0 12023020.2 1.0X +KryoPool:false 9289 11566 909 0.0 18578683.1 0.6X diff --git a/core/benchmarks/PropertiesCloneBenchmark-jdk11-results.txt b/core/benchmarks/PropertiesCloneBenchmark-jdk11-results.txt new file mode 100644 index 0000000000000..605b856d53382 --- /dev/null +++ b/core/benchmarks/PropertiesCloneBenchmark-jdk11-results.txt @@ -0,0 +1,40 @@ +================================================================================================ +Properties Cloning +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Empty Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 0 0 0 0.1 11539.0 1.0X +Utils.cloneProperties 0 0 0 1.7 572.0 20.2X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +System Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 0 0 0 0.0 217514.0 1.0X 
+Utils.cloneProperties 0 0 0 0.2 5387.0 40.4X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Small Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 1 1 0 0.0 634574.0 1.0X +Utils.cloneProperties 0 0 0 0.3 3082.0 205.9X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Medium Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 3 3 0 0.0 2576565.0 1.0X +Utils.cloneProperties 0 0 0 0.1 16071.0 160.3X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Large Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 5 5 0 0.0 5027248.0 1.0X +Utils.cloneProperties 0 0 0 0.0 31842.0 157.9X + + diff --git a/core/benchmarks/PropertiesCloneBenchmark-results.txt b/core/benchmarks/PropertiesCloneBenchmark-results.txt new file mode 100644 index 0000000000000..5d332a147c698 --- /dev/null +++ b/core/benchmarks/PropertiesCloneBenchmark-results.txt @@ -0,0 +1,40 @@ +================================================================================================ +Properties Cloning +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Empty Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 0 0 0 0.1 13640.0 1.0X +Utils.cloneProperties 0 0 0 1.6 608.0 22.4X + +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +System Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 0 0 0 0.0 238968.0 1.0X +Utils.cloneProperties 0 0 0 0.4 2318.0 103.1X + +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Small Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 1 1 0 0.0 725849.0 1.0X +Utils.cloneProperties 0 0 0 0.3 2900.0 250.3X + +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Medium Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 3 3 0 0.0 2999676.0 1.0X +Utils.cloneProperties 0 0 0 0.1 11734.0 255.6X + +OpenJDK 64-Bit Server VM 
1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Large Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SerializationUtils.clone 6 6 1 0.0 5846410.0 1.0X +Utils.cloneProperties 0 0 0 0.0 22405.0 260.9X + + diff --git a/core/benchmarks/XORShiftRandomBenchmark-jdk11-results.txt b/core/benchmarks/XORShiftRandomBenchmark-jdk11-results.txt new file mode 100644 index 0000000000000..9aa10e4835a2f --- /dev/null +++ b/core/benchmarks/XORShiftRandomBenchmark-jdk11-results.txt @@ -0,0 +1,44 @@ +================================================================================================ +Pseudo random +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +nextInt: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +java.util.Random 1362 1362 0 73.4 13.6 1.0X +XORShiftRandom 227 227 0 440.6 2.3 6.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +nextLong: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +java.util.Random 2725 2726 1 36.7 27.3 1.0X +XORShiftRandom 694 694 1 144.1 6.9 3.9X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +nextDouble: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +java.util.Random 2727 2728 0 36.7 27.3 1.0X +XORShiftRandom 693 694 0 144.2 6.9 3.9X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +nextGaussian: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +java.util.Random 7012 7016 4 14.3 70.1 1.0X +XORShiftRandom 6065 6067 1 16.5 60.7 1.2X + + +================================================================================================ +hash seed +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash seed: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +XORShiftRandom.hashSeed 36 37 1 276.5 3.6 1.0X + + diff --git a/core/benchmarks/XORShiftRandomBenchmark-results.txt b/core/benchmarks/XORShiftRandomBenchmark-results.txt index 1140489e4a7f3..4b069878b2e9b 100644 --- a/core/benchmarks/XORShiftRandomBenchmark-results.txt +++ b/core/benchmarks/XORShiftRandomBenchmark-results.txt @@ -2,43 +2,43 @@ Pseudo random ================================================================================================ -OpenJDK 64-Bit Server VM 
1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -nextInt: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -java.util.Random 1362 / 1362 73.4 13.6 1.0X -XORShiftRandom 227 / 227 440.6 2.3 6.0X +nextInt: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +java.util.Random 1362 1396 59 73.4 13.6 1.0X +XORShiftRandom 227 227 0 440.7 2.3 6.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -nextLong: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -java.util.Random 2732 / 2732 36.6 27.3 1.0X -XORShiftRandom 629 / 629 159.0 6.3 4.3X +nextLong: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +java.util.Random 2732 2732 1 36.6 27.3 1.0X +XORShiftRandom 630 630 1 158.7 6.3 4.3X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -nextDouble: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -java.util.Random 2730 / 2730 36.6 27.3 1.0X -XORShiftRandom 629 / 629 159.0 6.3 4.3X +nextDouble: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +java.util.Random 2731 2732 1 36.6 27.3 1.0X +XORShiftRandom 630 630 0 158.8 6.3 4.3X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -nextGaussian: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -java.util.Random 10288 / 10288 9.7 102.9 1.0X -XORShiftRandom 6351 / 6351 15.7 63.5 1.6X +nextGaussian: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +java.util.Random 8895 8899 4 11.2 88.9 1.0X +XORShiftRandom 5049 5052 5 19.8 50.5 1.8X ================================================================================================ hash seed ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Hash seed: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -XORShiftRandom.hashSeed 1193 / 1195 8.4 119.3 1.0X +Hash seed: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per 
Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +XORShiftRandom.hashSeed 67 68 1 148.8 6.7 1.0X diff --git a/core/pom.xml b/core/pom.xml index 42fc2c4b3a287..38eb8adac500e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -384,6 +384,11 @@ curator-test test + + org.apache.hadoop + hadoop-minikdc + test + net.razorvine pyrolite @@ -551,6 +556,15 @@ + + scala-2.13 + + + org.scala-lang.modules + scala-parallel-collections_${scala.binary.version} + + + diff --git a/core/src/main/java/org/apache/spark/ExecutorPlugin.java b/core/src/main/java/org/apache/spark/ExecutorPlugin.java index f86520c81df33..b25c46266247e 100644 --- a/core/src/main/java/org/apache/spark/ExecutorPlugin.java +++ b/core/src/main/java/org/apache/spark/ExecutorPlugin.java @@ -40,12 +40,15 @@ public interface ExecutorPlugin { * Initialize the executor plugin. * *

Each executor will, during its initialization, invoke this method on each - * plugin provided in the spark.executor.plugins configuration.

+ * plugin provided in the spark.executor.plugins configuration. The Spark executor + * will wait on the completion of the execution of the init method.

* *

Plugins should create threads in their implementation of this method for * any polling, blocking, or intensive computation.

+ * + * @param pluginContext Context information for the executor where the plugin is running. */ - default void init() {} + default void init(ExecutorPluginContext pluginContext) {} /** * Clean up and terminate this plugin. diff --git a/core/src/main/java/org/apache/spark/ExecutorPluginContext.java b/core/src/main/java/org/apache/spark/ExecutorPluginContext.java new file mode 100644 index 0000000000000..8f018732b8217 --- /dev/null +++ b/core/src/main/java/org/apache/spark/ExecutorPluginContext.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import com.codahale.metrics.MetricRegistry; +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.annotation.Private; + +/** + * Encapsulates information about the executor when initializing {@link ExecutorPlugin} instances. + */ +@DeveloperApi +public class ExecutorPluginContext { + + public final MetricRegistry metricRegistry; + public final SparkConf sparkConf; + public final String executorId; + public final String executorHostName; + public final boolean isLocal; + + @Private + public ExecutorPluginContext( + MetricRegistry registry, + SparkConf conf, + String id, + String hostName, + boolean local) { + metricRegistry = registry; + sparkConf = conf; + executorId = id; + executorHostName = hostName; + isLocal = local; + } + +} diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index 92bf0ecc1b5cb..a1e29a8c873da 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -51,7 +51,6 @@ public NioBufferedFileInputStream(File file) throws IOException { /** * Checks weather data is left to be read from the input stream. * @return true if data is left, false otherwise - * @throws IOException */ private boolean refill() throws IOException { if (!byteBuffer.hasRemaining()) { diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 4bfd2d358f36f..9a9d0c7946549 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -54,7 +54,7 @@ public MemoryMode getMode() { /** * Returns the size of used memory in bytes. 
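The new ExecutorPluginContext above exposes the executor's MetricRegistry, SparkConf, executor id, host name and local-mode flag to every plugin. As a rough sketch of how a plugin might use it (MyExecutorPlugin, the gauge name and the polling thread are illustrative assumptions; only ExecutorPlugin, ExecutorPluginContext and the init(ExecutorPluginContext) signature come from this patch):

    import com.codahale.metrics.Gauge;

    import org.apache.spark.ExecutorPlugin;
    import org.apache.spark.ExecutorPluginContext;

    // Hypothetical plugin, enabled through spark.executor.plugins as before.
    public class MyExecutorPlugin implements ExecutorPlugin {

      @Override
      public void init(ExecutorPluginContext pluginContext) {
        // The executor blocks until init() returns, so register metrics here and
        // push any polling or heavy work onto a background thread.
        pluginContext.metricRegistry.register(
            "myPlugin.isLocalMode", (Gauge<Boolean>) () -> pluginContext.isLocal);

        Thread poller = new Thread(() -> {
          // periodic work would go here
        }, "my-executor-plugin-poller");
        poller.setDaemon(true);
        poller.start();
      }
    }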
*/ - protected long getUsed() { + public long getUsed() { return used; } @@ -78,7 +78,6 @@ public void spill() throws IOException { * @param size the amount of memory should be released * @param trigger the MemoryConsumer that trigger this spilling * @return the amount of released memory in bytes - * @throws IOException */ public abstract long spill(long size, MemoryConsumer trigger) throws IOException; diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java index 70c112b78911d..d30f3dad3c940 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.api; import java.io.IOException; +import java.util.Optional; import org.apache.spark.annotation.Private; @@ -39,17 +40,31 @@ public interface ShuffleExecutorComponents { /** * Called once per map task to create a writer that will be responsible for persisting all the * partitioned bytes written by that map task. - * @param shuffleId Unique identifier for the shuffle the map task is a part of - * @param mapId Within the shuffle, the identifier of the map task - * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task - * with the same (shuffleId, mapId) pair can be distinguished by the - * different values of mapTaskAttemptId. + * + * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param mapId An ID of the map task. The ID is unique within this Spark application. * @param numPartitions The number of partitions that will be written by the map task. Some of -* these partitions may be empty. + * these partitions may be empty. */ ShuffleMapOutputWriter createMapOutputWriter( int shuffleId, - int mapId, - long mapTaskAttemptId, + long mapId, int numPartitions) throws IOException; + + /** + * An optional extension for creating a map output writer that can optimize the transfer of a + * single partition file, as the entire result of a map task, to the backing store. + *

+ * Most implementations should return the default {@link Optional#empty()} to indicate that + * they do not support this optimization. This primarily is for backwards-compatibility in + * preserving an optimization in the local disk shuffle storage implementation. + * + * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param mapId An ID of the map task. The ID is unique within this Spark application. + */ + default Optional createSingleFileMapOutputWriter( + int shuffleId, + long mapId) throws IOException { + return Optional.empty(); + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 7fac00b7fbc3f..21abe9a57cd25 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -39,7 +39,7 @@ public interface ShuffleMapOutputWriter { * for the same partition within any given map task. The partition identifier will be in the * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was * provided upon the creation of this map output writer via - * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}. + * {@link ShuffleExecutorComponents#createMapOutputWriter(int, long, int)}. *

* Calls to this method will be invoked with monotonically increasing reducePartitionIds; each * call to this method will be called with a reducePartitionId that is strictly greater than diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java new file mode 100644 index 0000000000000..cad8dcfda52bc --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.File; +import java.io.IOException; + +import org.apache.spark.annotation.Private; + +/** + * Optional extension for partition writing that is optimized for transferring a single + * file to the backing store. + */ +@Private +public interface SingleSpillShuffleMapOutputWriter { + + /** + * Transfer a file that contains the bytes of all the partitions written by this map task. 
+ */ + void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index f75e932860f90..dc157eaa3b253 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -85,8 +85,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; - private final int mapId; - private final long mapTaskAttemptId; + private final long mapId; private final Serializer serializer; private final ShuffleExecutorComponents shuffleExecutorComponents; @@ -106,8 +105,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { BypassMergeSortShuffleWriter( BlockManager blockManager, BypassMergeSortShuffleHandle handle, - int mapId, - long mapTaskAttemptId, + long mapId, SparkConf conf, ShuffleWriteMetricsReporter writeMetrics, ShuffleExecutorComponents shuffleExecutorComponents) { @@ -117,7 +115,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.blockManager = blockManager; final ShuffleDependency dep = handle.dependency(); this.mapId = mapId; - this.mapTaskAttemptId = mapTaskAttemptId; this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); @@ -130,11 +127,12 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { public void write(Iterator> records) throws IOException { assert (partitionWriters == null); ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents - .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions); + .createMapOutputWriter(shuffleId, mapId, numPartitions); try { if (!records.hasNext()) { partitionLengths = mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, mapId); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -167,7 +165,8 @@ public void write(Iterator> records) throws IOException { } partitionLengths = writePartitionedData(mapOutputWriter); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, mapId); } catch (Exception e) { try { mapOutputWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 024756087bf7f..833744f4777ce 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -423,7 +423,6 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p * * @return metadata for the spill files written by this sorter. If no records were ever inserted * into this sorter, then this will return an empty array. 
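The LocalDiskSingleSpillMapOutputWriter used further down is not shown in this section, so here is a rough sketch of what an implementation of the new SingleSpillShuffleMapOutputWriter interface could look like (the class name and the destination handling are illustrative assumptions, not part of the patch):

    import java.io.File;
    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.StandardCopyOption;

    import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;

    // Illustrative only: takes the single finished spill file as-is and moves it to
    // wherever this (hypothetical) storage plugin keeps map output. A real
    // implementation would also persist an index derived from partitionLengths,
    // which this sketch omits.
    final class RenamingSingleSpillWriter implements SingleSpillShuffleMapOutputWriter {
      private final File destination;

      RenamingSingleSpillWriter(File destination) {
        this.destination = destination;
      }

      @Override
      public void transferMapSpillFile(File mapOutputFile, long[] partitionLengths)
          throws IOException {
        // The spill already holds all partitions back to back, so no merge pass is
        // needed here.
        Files.move(mapOutputFile.toPath(), destination.toPath(),
            StandardCopyOption.REPLACE_EXISTING);
      }
    }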
- * @throws IOException */ public SpillInfo[] closeAndGetSpills() throws IOException { if (inMemSorter != null) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9d05f03613ce9..d09282e61a9c7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -17,9 +17,12 @@ package org.apache.spark.shuffle.sort; +import java.nio.channels.Channels; +import java.util.Optional; import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; import java.util.Iterator; import scala.Option; @@ -31,7 +34,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; -import com.google.common.io.Files; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,8 +43,6 @@ import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; -import org.apache.commons.io.output.CloseShieldOutputStream; -import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -50,8 +50,12 @@ import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; @@ -65,23 +69,21 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); @VisibleForTesting - static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; private final BlockManager blockManager; - private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; + private final ShuffleExecutorComponents shuffleExecutorComponents; private final int shuffleId; - private final int mapId; + private final long mapId; private final TaskContext taskContext; private final SparkConf sparkConf; private final boolean transferToEnabled; private final int initialSortBufferSize; private final int inputBufferSizeInBytes; - private final int outputBufferSizeInBytes; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -103,27 +105,15 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream */ private boolean stopping = false; - private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream { - - 
CloseAndFlushShieldOutputStream(OutputStream outputStream) { - super(outputStream); - } - - @Override - public void flush() { - // do nothing - } - } - public UnsafeShuffleWriter( BlockManager blockManager, - IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, SerializedShuffleHandle handle, - int mapId, + long mapId, TaskContext taskContext, SparkConf sparkConf, - ShuffleWriteMetricsReporter writeMetrics) throws IOException { + ShuffleWriteMetricsReporter writeMetrics, + ShuffleExecutorComponents shuffleExecutorComponents) { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -132,7 +122,6 @@ public UnsafeShuffleWriter( " reduce partitions"); } this.blockManager = blockManager; - this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); @@ -140,6 +129,7 @@ public UnsafeShuffleWriter( this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); this.writeMetrics = writeMetrics; + this.shuffleExecutorComponents = shuffleExecutorComponents; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); @@ -147,8 +137,6 @@ public UnsafeShuffleWriter( (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); this.inputBufferSizeInBytes = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; - this.outputBufferSizeInBytes = - (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; open(); } @@ -231,25 +219,17 @@ void closeAndWriteOutput() throws IOException { final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; final long[] partitionLengths; - final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); - final File tmp = Utils.tempFileWith(output); try { - try { - partitionLengths = mergeSpills(spills, tmp); - } finally { - for (SpillInfo spill : spills) { - if (spill.file.exists() && ! spill.file.delete()) { - logger.error("Error while deleting spill file {}", spill.file.getPath()); - } - } - } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + partitionLengths = mergeSpills(spills); } finally { - if (tmp.exists() && !tmp.delete()) { - logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, mapId); } @VisibleForTesting @@ -281,137 +261,153 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. 
*/ - private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { + private long[] mergeSpills(SpillInfo[] spills) throws IOException { + long[] partitionLengths; + if (spills.length == 0) { + final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents + .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); + return mapWriter.commitAllPartitions(); + } else if (spills.length == 1) { + Optional maybeSingleFileWriter = + shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId); + if (maybeSingleFileWriter.isPresent()) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. + partitionLengths = spills[0].partitionLengths; + maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths); + } else { + partitionLengths = mergeSpillsUsingStandardWriter(spills); + } + } else { + partitionLengths = mergeSpillsUsingStandardWriter(spills); + } + return partitionLengths; + } + + private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException { + long[] partitionLengths; final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS()); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = - (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE()); + (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE()); final boolean fastMergeIsSupported = !compressionEnabled || - CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); + final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents + .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); try { - if (spills.length == 0) { - new FileOutputStream(outputFile).close(); // Create an empty file - return new long[partitioner.numPartitions()]; - } else if (spills.length == 1) { - // Here, we don't need to perform any metrics updates because the bytes written to this - // output file would have already been counted as shuffle bytes written. - Files.move(spills[0].file, outputFile); - return spills[0].partitionLengths; - } else { - final long[] partitionLengths; - // There are multiple spills to merge, so none of these spill files' lengths were counted - // towards our shuffle write count or shuffle write time. If we use the slow merge path, - // then the final output file's size won't necessarily be equal to the sum of the spill - // files' sizes. To guard against this case, we look at the output file's actual size when - // computing shuffle bytes written. - // - // We allow the individual merge methods to report their own IO times since different merge - // strategies use different IO techniques. We count IO during merge towards the shuffle - // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" - // branch in ExternalSorter. - if (fastMergeEnabled && fastMergeIsSupported) { - // Compression is disabled or we are using an IO compression codec that supports - // decompression of concatenated compressed streams, so we can perform a fast spill merge - // that doesn't need to interpret the spilled bytes. 
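For reference, the choice between the empty, single-spill and standard merge paths in mergeSpills, and between the transferTo-based and stream-based merges below, is steered entirely by pre-existing configuration; a sketch of the relevant settings with their defaults spelled out (the spark.shuffle.unsafe.fastMergeEnabled key is my reading of SHUFFLE_UNSAFE_FAST_MERGE_ENABLE and worth double-checking):

    import org.apache.spark.SparkConf;

    public class MergePathDefaults {
      public static void main(String[] args) {
        // transferTo-based fast merging is used only when the fast-merge flag is on,
        // compression is off or uses a concatenation-friendly codec, I/O encryption
        // is off, and spark.file.transferTo has not been disabled.
        SparkConf conf = new SparkConf()
            .set("spark.shuffle.compress", "true")                 // compression on...
            .set("spark.io.compression.codec", "lz4")              // ...with a codec that supports concatenation
            .set("spark.shuffle.unsafe.fastMergeEnabled", "true")  // allow the fast merge path
            .set("spark.file.transferTo", "true");                 // allow the NIO transferTo variant
        System.out.println(conf.toDebugString());
      }
    }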
- if (transferToEnabled && !encryptionEnabled) { - logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); - } else { - logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); - } + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // write time, which appears to be consistent with the "not bypassing merge-sort" branch in + // ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. + if (transferToEnabled && !encryptionEnabled) { + logger.debug("Using transferTo-based fast merge"); + mergeSpillsWithTransferTo(spills, mapWriter); } else { - logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + logger.debug("Using fileStream-based fast merge"); + mergeSpillsWithFileStream(spills, mapWriter, null); } - // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has - // in-memory records, we write out the in-memory records to a file but do not count that - // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs - // to be counted as shuffle write, but this will lead to double-counting of the final - // SpillInfo's bytes. - writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - writeMetrics.incBytesWritten(outputFile.length()); - return partitionLengths; + } else { + logger.debug("Using slow merge"); + mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); } - } catch (IOException e) { - if (outputFile.exists() && !outputFile.delete()) { - logger.error("Unable to delete output file {}", outputFile.getPath()); + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); + partitionLengths = mapWriter.commitAllPartitions(); + } catch (Exception e) { + try { + mapWriter.abort(e); + } catch (Exception e2) { + logger.warn("Failed to abort writing the map output.", e2); + e.addSuppressed(e2); } throw e; } + return partitionLengths; } /** * Merges spill files using Java FileStreams. 
This code path is typically slower than * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], - * File)}, and it's mostly used in cases where the IO compression codec does not support - * concatenation of compressed data, when encryption is enabled, or when users have - * explicitly disabled use of {@code transferTo} in order to work around kernel bugs. + * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec + * does not support concatenation of compressed data, when encryption is enabled, or when + * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs. * This code path might also be faster in cases where individual partition size in a spill * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small * disk ios which is inefficient. In those case, Using large buffers for input and output * files helps reducing the number of disk ios, making the file merging faster. * * @param spills the spills to merge. - * @param outputFile the file to write the merged data to. + * @param mapWriter the map output writer to use for output. * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. * @return the partition lengths in the merged file. */ - private long[] mergeSpillsWithFileStream( + private void mergeSpillsWithFileStream( SpillInfo[] spills, - File outputFile, + ShuffleMapOutputWriter mapWriter, @Nullable CompressionCodec compressionCodec) throws IOException { - assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; - final OutputStream bos = new BufferedOutputStream( - new FileOutputStream(outputFile), - outputBufferSizeInBytes); - // Use a counting output stream to avoid having to close the underlying file and ask - // the file system for its size after each partition is written. - final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); - boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputStreams[i] = new NioBufferedFileInputStream( - spills[i].file, - inputBufferSizeInBytes); + spills[i].file, + inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { - final long initialFileLength = mergedFileOutputStream.getByteCount(); - // Shield the underlying output stream from close() and flush() calls, so that we can close - // the higher level streams to make sure all data is really flushed and internal state is - // cleaned. 
- OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( - new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); - partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); - if (compressionCodec != null) { - partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); - } - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], - partitionLengthInSpill, false); - try { - partitionInputStream = blockManager.serializerManager().wrapForEncryption( - partitionInputStream); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + boolean copyThrewException = true; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + OutputStream partitionOutput = writer.openStream(); + try { + partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); + if (compressionCodec != null) { + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); + } + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = null; + boolean copySpillThrewException = true; + try { + partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream( + partitionInputStream); + } + ByteStreams.copy(partitionInputStream, partitionOutput); + copySpillThrewException = false; + } finally { + Closeables.close(partitionInputStream, copySpillThrewException); } - ByteStreams.copy(partitionInputStream, partitionOutput); - } finally { - partitionInputStream.close(); } } + copyThrewException = false; + } finally { + Closeables.close(partitionOutput, copyThrewException); } - partitionOutput.flush(); - partitionOutput.close(); - partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); + long numBytesWritten = writer.getNumBytesWritten(); + writeMetrics.incBytesWritten(numBytesWritten); } threwException = false; } finally { @@ -420,9 +416,7 @@ private long[] mergeSpillsWithFileStream( for (InputStream stream : spillInputStreams) { Closeables.close(stream, threwException); } - Closeables.close(mergedFileOutputStream, threwException); } - return partitionLengths; } /** @@ -430,54 +424,46 @@ private long[] mergeSpillsWithFileStream( * This is only safe when the IO compression codec and serializer support concatenation of * serialized streams. * + * @param spills the spills to merge. + * @param mapWriter the map output writer to use for output. * @return the partition lengths in the merged file. 
*/ - private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { - assert (spills.length >= 2); + private void mergeSpillsWithTransferTo( + SpillInfo[] spills, + ShuffleMapOutputWriter mapWriter) throws IOException { final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; - FileChannel mergedFileOutputChannel = null; boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); } - // This file needs to opened in append mode in order to work around a Linux kernel bug that - // affects transferTo; see SPARK-3948 for more details. - mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); - - long bytesWrittenToMergedFile = 0; for (int partition = 0; partition < numPartitions; partition++) { - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - final FileChannel spillInputChannel = spillInputChannels[i]; - final long writeStartTime = System.nanoTime(); - Utils.copyFileStreamNIO( - spillInputChannel, - mergedFileOutputChannel, - spillInputChannelPositions[i], - partitionLengthInSpill); - spillInputChannelPositions[i] += partitionLengthInSpill; - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); - bytesWrittenToMergedFile += partitionLengthInSpill; - partitionLengths[partition] += partitionLengthInSpill; + boolean copyThrewException = true; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper() + .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer))); + try { + for (int i = 0; i < spills.length; i++) { + long partitionLengthInSpill = spills[i].partitionLengths[partition]; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + Utils.copyFileStreamNIO( + spillInputChannel, + resolvedChannel.channel(), + spillInputChannelPositions[i], + partitionLengthInSpill); + copyThrewException = false; + spillInputChannelPositions[i] += partitionLengthInSpill; + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + } + } finally { + Closeables.close(resolvedChannel, copyThrewException); } - } - // Check the position after transferTo loop to see if it is in the right position and raise an - // exception if it is incorrect. The position will not be increased to the expected length - // after calling transferTo in kernel version 2.6.32. This issue is described at - // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. - if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { - throw new IOException( - "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + - "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + - " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + - "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + - "to disable this NIO feature." 
- ); + long numBytes = writer.getNumBytesWritten(); + writeMetrics.incBytesWritten(numBytes); } threwException = false; } finally { @@ -487,9 +473,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th assert(spillInputChannelPositions[i] == spills[i].file.length()); Closeables.close(spillInputChannels[i], threwException); } - Closeables.close(mergedFileOutputChannel, threwException); } - return partitionLengths; } @Override @@ -518,4 +502,30 @@ public Option stop(boolean success) { } } } + + private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) { + try { + return writer.openStream(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper { + private final WritableByteChannel channel; + + StreamFallbackChannelWrapper(OutputStream fallbackStream) { + this.channel = Channels.newChannel(fallbackStream); + } + + @Override + public WritableByteChannel channel() { + return channel; + } + + @Override + public void close() throws IOException { + channel.close(); + } + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java index 02eb710737285..a0c7d3c248d48 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.sort.io; +import java.util.Optional; + import com.google.common.annotations.VisibleForTesting; import org.apache.spark.SparkConf; @@ -24,6 +26,7 @@ import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.storage.BlockManager; public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents { @@ -58,8 +61,7 @@ public void initializeExecutor(String appId, String execId) { @Override public ShuffleMapOutputWriter createMapOutputWriter( int shuffleId, - int mapId, - long mapTaskAttemptId, + long mapId, int numPartitions) { if (blockResolver == null) { throw new IllegalStateException( @@ -68,4 +70,15 @@ public ShuffleMapOutputWriter createMapOutputWriter( return new LocalDiskShuffleMapOutputWriter( shuffleId, mapId, numPartitions, blockResolver, sparkConf); } + + @Override + public Optional createSingleFileMapOutputWriter( + int shuffleId, + long mapId) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)); + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index 7fc19b1270a46..a6529fd76188a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -24,8 +24,8 @@ import java.io.OutputStream; import java.nio.channels.FileChannel; import java.nio.channels.WritableByteChannel; - import 
java.util.Optional; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,12 +48,13 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class); private final int shuffleId; - private final int mapId; + private final long mapId; private final IndexShuffleBlockResolver blockResolver; private final long[] partitionLengths; private final int bufferSize; private int lastPartitionId = -1; private long currChannelPosition; + private long bytesWrittenToMergedFile = 0L; private final File outputFile; private File outputTempFile; @@ -63,7 +64,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { public LocalDiskShuffleMapOutputWriter( int shuffleId, - int mapId, + long mapId, int numPartitions, IndexShuffleBlockResolver blockResolver, SparkConf sparkConf) { @@ -97,6 +98,18 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I @Override public long[] commitAllPartitions() throws IOException { + // Check the position after the transferTo loop to see if it is in the right position and raise an + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + outputFileChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your " + + "kernel version to see if it is 2.6.32, as there is a kernel bug which will lead " + + "to unexpected behavior when using transferTo. You can set " + + "spark.file.transferTo=false to disable this NIO feature."); + } cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null; blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); @@ -133,11 +146,10 @@ private void initStream() throws IOException { } private void initChannel() throws IOException { - if (outputFileStream == null) { - outputFileStream = new FileOutputStream(outputTempFile, true); - } + // This file needs to be opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details.
if (outputFileChannel == null) { - outputFileChannel = outputFileStream.getChannel(); + outputFileChannel = new FileOutputStream(outputTempFile, true).getChannel(); } } @@ -227,6 +239,7 @@ public void write(byte[] buf, int pos, int length) throws IOException { public void close() { isClosed = true; partitionLengths[partitionId] = count; + bytesWrittenToMergedFile += count; } private void verifyNotClosed() { @@ -257,6 +270,7 @@ public WritableByteChannel channel() { @Override public void close() throws IOException { partitionLengths[partitionId] = getCount(); + bytesWrittenToMergedFile += partitionLengths[partitionId]; } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java new file mode 100644 index 0000000000000..c8b41992a8919 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; + +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.util.Utils; + +public class LocalDiskSingleSpillMapOutputWriter + implements SingleSpillShuffleMapOutputWriter { + + private final int shuffleId; + private final long mapId; + private final IndexShuffleBlockResolver blockResolver; + + public LocalDiskSingleSpillMapOutputWriter( + int shuffleId, + long mapId, + IndexShuffleBlockResolver blockResolver) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.blockResolver = blockResolver; + } + + @Override + public void transferMapSpillFile( + File mapSpillFile, + long[] partitionLengths) throws IOException { + // The map spill file already has the proper format, and it contains all of the partition data. + // So just transfer it directly to the destination without any merging. 
+ File outputFile = blockResolver.getDataFile(shuffleId, mapId); + File tempFile = Utils.tempFileWith(outputFile); + Files.move(mapSpillFile.toPath(), tempFile.toPath()); + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); + } +} diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d320ba3139541..b15365fe54ad6 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -886,6 +886,7 @@ public void reset() { numKeys = 0; numValues = 0; freeArray(longArray); + longArray = null; while (dataPages.size() > 0) { MemoryBlock dataPage = dataPages.removeLast(); freePage(dataPage); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 1b206c11d9a8e..55e4e609c3c7b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -447,8 +447,6 @@ public void insertKVRecord(Object keyBase, long keyOffset, int keyLen, /** * Merges another UnsafeExternalSorters into this one, the other one will be emptied. - * - * @throws IOException */ public void merge(UnsafeExternalSorter other) throws IOException { other.spill(); diff --git a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js index 3ef1a76fd7202..b28c981da20a5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js @@ -286,7 +286,7 @@ $(document).ready(function () { " Show Additional Metrics" + "" + " - You may also provide a "message handler function" that takes a Kinesis `Record` and returns a generic object `T`, in case you would like to use other data included in a `Record` such as partition key. This is currently only supported in Scala and Java. + You may also provide the following settings. These are currently only supported in Scala and Java. + + - A "message handler function" that takes a Kinesis `Record` and returns a generic object `T`, in case you would like to use other data included in a `Record` such as partition key. + + - CloudWatch metrics level and dimensions. See [the AWS documentation about monitoring KCL](https://docs.aws.amazon.com/streams/latest/dev/monitoring-with-kcl.html) for details.
+ import collection.JavaConverters._ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.kinesis.KinesisInputDStream import org.apache.spark.streaming.{Seconds, StreamingContext} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration + import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel val kinesisStream = KinesisInputDStream.builder .streamingContext(streamingContext) @@ -116,17 +123,22 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m .checkpointAppName([Kinesis app name]) .checkpointInterval([checkpoint interval]) .storageLevel(StorageLevel.MEMORY_AND_DISK_2) + .metricsLevel(MetricsLevel.DETAILED) + .metricsEnabledDimensions(KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS.asScala.toSet) .buildWithMessageHandler([message handler])
- import org.apache.spark.storage.StorageLevel - import org.apache.spark.streaming.kinesis.KinesisInputDStream - import org.apache.spark.streaming.Seconds - import org.apache.spark.streaming.StreamingContext - import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream - - KinesisInputDStream kinesisStream = KinesisInputDStream.builder + import org.apache.spark.storage.StorageLevel; + import org.apache.spark.streaming.kinesis.KinesisInputDStream; + import org.apache.spark.streaming.Seconds; + import org.apache.spark.streaming.StreamingContext; + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration; + import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel; + import scala.collection.JavaConverters; + + KinesisInputDStream kinesisStream = KinesisInputDStream.builder() .streamingContext(streamingContext) .endpointUrl([endpoint URL]) .regionName([region name]) @@ -135,6 +147,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m .checkpointAppName([Kinesis app name]) .checkpointInterval([checkpoint interval]) .storageLevel(StorageLevel.MEMORY_AND_DISK_2) + .metricsLevel(MetricsLevel.DETAILED) + .metricsEnabledDimensions(JavaConverters.asScalaSetConverter(KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS).asScala().toSet()) .buildWithMessageHandler([message handler]);
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index f5abed74bff20..f6b579fbf74d1 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2488,13 +2488,13 @@ additional effort may be necessary to achieve exactly-once semantics. There are * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) * [KafkaUtils](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$), - [KinesisUtils](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$), + [KinesisUtils](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisInputDStream), - Java docs * [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html), [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html) * [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html), - [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) + [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisInputDStream.html) - Python docs * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) and [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) * [KafkaUtils](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index c4378b4a02663..89732d309aa27 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -27,6 +27,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} +Please note that to use the headers functionality, your Kafka client version should be version 0.11.0.0 or up. + For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. 
@@ -50,6 +52,17 @@ val df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] +// Subscribe to 1 topic, with headers +val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .option("includeHeaders", "true") + .load() +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers") + .as[(String, String, Map)] + // Subscribe to multiple topics val df = spark .readStream @@ -84,6 +97,16 @@ Dataset df = spark .load(); df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +// Subscribe to 1 topic, with headers +Dataset df = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .option("includeHeaders", "true") + .load() +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers"); + // Subscribe to multiple topics Dataset df = spark .readStream() @@ -116,6 +139,16 @@ df = spark \ .load() df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +# Subscribe to 1 topic, with headers +val df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ + .option("includeHeaders", "true") \ + .load() +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers") + # Subscribe to multiple topics df = spark \ .readStream \ @@ -286,6 +319,10 @@ Each row in the source has the following schema: timestampType int + + headers (optional) + array + The following options must be set for the Kafka source @@ -325,6 +362,27 @@ The following configurations are optional: + + + + + + + + + + + + + + + + + + + + +
Option | value | default | query type | meaning
startingOffsetsByTimestamp | json string + """ {"topicA":{"0": 1000, "1": 1000}, "topicB": {"0": 2000, "1": 2000}} """ + none (the value of startingOffsets will apply) | streaming and batch | The start point in time when a query is started, a json string specifying a starting timestamp for + each TopicPartition. The returned offset for each partition is the earliest offset whose timestamp is greater than or + equal to the given timestamp in the corresponding partition. If the matched offset doesn't exist, + the query will fail immediately to prevent unintended reads from such a partition. (This is a limitation for now, and will be addressed in the near future.) See the example sketch after this options table.

+

+ Spark simply passes the timestamp information to KafkaConsumer.offsetsForTimes, and doesn't interpret or reason about the value.

+ For more details on KafkaConsumer.offsetsForTimes, please refer to its javadoc.

+ Also, the meaning of the timestamp here can vary according to the Kafka configuration (log.message.timestamp.type); please refer to the Kafka documentation for further details.

+ Note: This option requires Kafka 0.10.1.0 or higher.

+ Note2: startingOffsetsByTimestamp takes precedence over startingOffsets.

+ Note3: For streaming queries, this only applies when a new query is started; resuming will + always pick up from where the query left off. Newly discovered partitions during a query will start at + earliest.

startingOffsets "earliest", "latest" (streaming only), or json string @@ -340,6 +398,25 @@ The following configurations are optional: always pick up from where the query left off. Newly discovered partitions during a query will start at earliest.
endingOffsetsByTimestamp | json string + """ {"topicA":{"0": 1000, "1": 1000}, "topicB": {"0": 2000, "1": 2000}} """ + latest | batch query | The end point when a batch query is ended, a json string specifying an ending timestamp for each TopicPartition. + The returned offset for each partition is the earliest offset whose timestamp is greater than or equal to + the given timestamp in the corresponding partition. If the matched offset doesn't exist, the offset will + be set to latest.

+

+ Spark simply passes the timestamp information to KafkaConsumer.offsetsForTimes, and doesn't interpret or reason about the value.

+ For more details on KafkaConsumer.offsetsForTimes, please refer to its javadoc.

+ Also, the meaning of the timestamp here can vary according to the Kafka configuration (log.message.timestamp.type); please refer to the Kafka documentation for further details.

+ Note: This option requires Kafka 0.10.1.0 or higher.

+ Note2: endingOffsetsByTimestamp takes precedence over endingOffsets. +

endingOffsets latest or json string @@ -425,6 +502,13 @@ The following configurations are optional: issues, set the Kafka consumer session timeout (by setting option "kafka.session.timeout.ms") to be very small. When this is set, option "groupIdPrefix" will be ignored.
includeHeaders | boolean | false | streaming and batch | Whether to include the Kafka headers in the row.
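To make the options above concrete, here is a minimal sketch of a batch read that combines the new timestamp-based offset options with includeHeaders. It only reuses the option names and the json format shown in this table, following the same builder pattern as the examples earlier in this file; the bootstrap servers, topic name, partition ids, and timestamp values are placeholders.

    // Batch query bounded by per-partition timestamps, exposing record headers.
    // Topic, servers, and timestamps below are illustrative only.
    val df = spark
      .read
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
      .option("subscribe", "topicA")
      // start from the earliest offsets whose timestamps are >= the given per-partition timestamps
      .option("startingOffsetsByTimestamp", """{"topicA":{"0": 1000, "1": 1000}}""")
      // batch queries may additionally bound the read with per-partition ending timestamps
      .option("endingOffsetsByTimestamp", """{"topicA":{"0": 2000, "1": 2000}}""")
      // expose Kafka record headers as an extra column
      .option("includeHeaders", "true")
      .load()
    df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")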
### Consumer Caching @@ -522,6 +606,10 @@ The Dataframe being written to Kafka should have the following columns in schema value (required) string or binary + + headers (optional) + array + topic (*optional) string @@ -559,6 +647,13 @@ The following configurations are optional: Sets the topic that all rows will be written to in Kafka. This option overrides any topic column that may exist in the data. + + includeHeaders + boolean + false + streaming and batch + Whether to include the Kafka headers in the row. + ### Creating a Kafka Sink for Streaming Queries @@ -825,7 +920,9 @@ Delegation tokens can be obtained from multiple clusters and ${cluster}spark.kafka.clusters.${cluster}.security.protocol SASL_SSL - Protocol used to communicate with brokers. For further details please see Kafka documentation. Only used to obtain delegation token. + Protocol used to communicate with brokers. For further details please see Kafka documentation. Protocol is applied on all the sources and sinks as default where + bootstrap.servers config matches (for further details please see spark.kafka.clusters.${cluster}.target.bootstrap.servers.regex), + and can be overridden by setting kafka.security.protocol on the source or sink. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index deaf262c5f572..2a405f36fd5fd 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1505,7 +1505,6 @@ Additional details on supported joins: - Cannot use mapGroupsWithState and flatMapGroupsWithState in Update mode before joins. - ### Streaming Deduplication You can deduplicate records in data streams using a unique identifier in the events. This is exactly same as deduplication on static using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. @@ -1616,6 +1615,8 @@ this configuration judiciously. ### Arbitrary Stateful Operations Many usecases require more advanced stateful operations than aggregations. For example, in many usecases, you have to track sessions from data streams of events. For doing such sessionization, you will have to save arbitrary types of data as state, and perform arbitrary operations on the state using the data stream events in every trigger. Since Spark 2.2, this can be done using the operation `mapGroupsWithState` and the more powerful operation `flatMapGroupsWithState`. Both operations allow you to apply user-defined code on grouped Datasets to update user-defined state. For more concrete details, take a look at the API documentation ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.GroupState)/[Java](api/java/org/apache/spark/sql/streaming/GroupState.html)) and the examples ([Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java)). +Though Spark cannot check and force it, the state function should be implemented with respect to the semantics of the output mode. 
For example, in Update mode Spark doesn't expect that the state function will emit rows which are older than current watermark plus allowed late record delay, whereas in Append mode the state function can emit these rows. + ### Unsupported Operations There are a few DataFrame/Dataset operations that are not supported with streaming DataFrames/Datasets. Some of them are as follows. @@ -1647,6 +1648,26 @@ For example, sorting on the input stream is not supported, as it requires keepin track of all the data received in the stream. This is therefore fundamentally hard to execute efficiently. +### Limitation of global watermark + +In Append mode, if a stateful operation emits rows older than current watermark plus allowed late record delay, +they will be "late rows" in downstream stateful operations (as Spark uses global watermark). Note that these rows may be discarded. +This is a limitation of a global watermark, and it could potentially cause a correctness issue. + +Spark will check the logical plan of query and log a warning when Spark detects such a pattern. + +Any of the stateful operation(s) after any of below stateful operations can have this issue: + +* streaming aggregation in Append mode +* stream-stream outer join +* `mapGroupsWithState` and `flatMapGroupsWithState` in Append mode (depending on the implementation of the state function) + +As Spark cannot check the state function of `mapGroupsWithState`/`flatMapGroupsWithState`, Spark assumes that the state function +emits late rows if the operator uses Append mode. + +There's a known workaround: split your streaming query into multiple queries per stateful operator, and ensure +end-to-end exactly once per query. Ensuring end-to-end exactly once for the last query is optional. + ## Starting Streaming Queries Once you have defined the final result DataFrame/Dataset, all that is left is for you to start the streaming computation. To do that, you have to use the `DataStreamWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) diff --git a/docs/web-ui.md b/docs/web-ui.md index 72423d9468e83..e6025370e6796 100644 --- a/docs/web-ui.md +++ b/docs/web-ui.md @@ -404,3 +404,44 @@ The web UI includes a Streaming tab if the application uses Spark streaming. Thi scheduling delay and processing time for each micro-batch in the data stream, which can be useful for troubleshooting the streaming application. +## JDBC/ODBC Server Tab +We can see this tab when Spark is running as a [distributed SQL engine](sql-distributed-sql-engine.html). It shows information about sessions and submitted SQL operations. + +The first section of the page displays general information about the JDBC/ODBC server: start time and uptime. + +

+ JDBC/ODBC Header +

+ +The second section contains information about active and finished sessions. +* **User** and **IP** of the connection. +* **Session id** link to access the session info. +* **Start time**, **finish time** and **duration** of the session. +* **Total execute** is the number of operations submitted in this session. + +

+ JDBC/ODBC sessions +

+ +The third section has the SQL statistics of the submitted operations. +* **User** that submitted the operation. +* **Job id** link to [jobs tab](web-ui.html#jobs-tab). +* **Group id** of the query that groups all jobs together. An application can cancel all running jobs using this group id. +* **Start time** of the operation. +* **Finish time** of the execution, before fetching the results. +* **Close time** of the operation after fetching the results. +* **Execution time** is the difference between finish time and start time. +* **Duration time** is the difference between close time and start time. +* **Statement** is the operation being executed. +* **State** of the process. + * _Started_, first state, when the process begins. + * _Compiled_, execution plan generated. + * _Failed_, final state when the execution failed or finished with an error. + * _Canceled_, final state when the execution is canceled. + * _Finished_, processing finished and waiting to fetch results. + * _Closed_, final state when the client closed the statement. +* **Detail** of the execution plan with parsed logical plan, analyzed logical plan, optimized logical plan and physical plan or errors in the SQL statement. + +

+ JDBC/ODBC SQL Statistics +

diff --git a/examples/pom.xml b/examples/pom.xml index ac148ef4c9c01..a099f1e042e99 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -107,7 +107,7 @@ com.github.scopt scopt_${scala.binary.version} - 3.7.0 + 3.7.1 ${hive.parquet.group} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java deleted file mode 100644 index 324a781c1a44a..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; - -// $example on$ -import scala.Tuple2; - -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -// $example off$ - -/** - * Example for LinearRegressionWithSGD. 
- */ -public class JavaLinearRegressionWithSGDExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithSGDExample"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // $example on$ - // Load and parse the data - String path = "data/mllib/ridge-data/lpsa.data"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map(line -> { - String[] parts = line.split(","); - String[] features = parts[1].split(" "); - double[] v = new double[features.length]; - for (int i = 0; i < features.length - 1; i++) { - v[i] = Double.parseDouble(features[i]); - } - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - }); - parsedData.cache(); - - // Building the model - int numIterations = 100; - double stepSize = 0.00000001; - LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize); - - // Evaluate model on training examples and compute training error - JavaPairRDD valuesAndPreds = parsedData.mapToPair(point -> - new Tuple2<>(model.predict(point.features()), point.label())); - - double MSE = valuesAndPreds.mapToDouble(pair -> { - double diff = pair._1() - pair._2(); - return diff * diff; - }).mean(); - System.out.println("training Mean Squared Error = " + MSE); - - // Save and load model - model.save(sc.sc(), "target/tmp/javaLinearRegressionWithSGDModel"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), - "target/tmp/javaLinearRegressionWithSGDModel"); - // $example off$ - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java deleted file mode 100644 index 00033b5730a3d..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.mllib; - -// $example on$ -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.SparkConf; -// $example off$ - -public class JavaRegressionMetricsExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Java Regression Metrics Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - // $example on$ - // Load and parse the data - String path = "data/mllib/sample_linear_regression_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map(line -> { - String[] parts = line.split(" "); - double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length; i++) { - v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); - } - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - }); - parsedData.cache(); - - // Building the model - int numIterations = 100; - LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), - numIterations); - - // Evaluate model on training examples and compute training error - JavaPairRDD valuesAndPreds = parsedData.mapToPair(point -> - new Tuple2<>(model.predict(point.features()), point.label())); - - // Instantiate metrics object - RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); - - // Squared error - System.out.format("MSE = %f\n", metrics.meanSquaredError()); - System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R Squared = %f\n", metrics.r2()); - - // Mean absolute error - System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); - - // Explained variance - System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); - - // Save and load model - model.save(sc.sc(), "target/tmp/LogisticRegressionModel"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), - "target/tmp/LogisticRegressionModel"); - // $example off$ - - sc.stop(); - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala b/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala index 5d9a9a73f12ec..36da10568989d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.SparkSession * accumulator source) are reported to stdout as well. 
*/ object AccumulatorMetricsTest { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 3311de12dbd97..d7e79966037cc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.SparkSession * Usage: BroadcastTest [partitions] [numElem] [blockSize] */ object BroadcastTest { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val blockSize = if (args.length > 2) args(2) else "4096" diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index d12ef642bd2cd..ed56108f4b624 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.Utils * test driver submission in the standalone scheduler. */ object DriverSubmissionTest { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 1) { println("Usage: DriverSubmissionTest ") System.exit(0) diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala index 45c4953a84be2..6e95318a8cbc0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import org.apache.spark.sql.SparkSession object ExceptionHandlingTest { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("ExceptionHandlingTest") diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 2f2bbb1275438..c07c1afbcb174 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] */ object GroupByTest { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("GroupBy Test") diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index b327e13533b81..48698678571e3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession object HdfsTest { /** Usage: HdfsTest [file] */ - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 1) { System.err.println("Usage: HdfsTest ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 3f9cea35d6503..87c2f6853807a 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -93,7 +93,7 @@ object LocalALS { new CholeskyDecomposition(XtX).getSolver.solve(Xty) } - def showWarning() { + def showWarning(): Unit = { System.err.println( """WARN: This is a naive implementation of ALS and is given as an example! |Please use org.apache.spark.ml.recommendation.ALS @@ -101,7 +101,7 @@ object LocalALS { """.stripMargin) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { args match { case Array(m, u, f, iters) => diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 5512e33e41ac3..5478c585a959e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -39,7 +39,7 @@ object LocalFileLR { DataPoint(new DenseVector(nums.slice(1, D + 1)), nums(0)) } - def showWarning() { + def showWarning(): Unit = { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! |Please use org.apache.spark.ml.classification.LogisticRegression @@ -47,7 +47,7 @@ object LocalFileLR { """.stripMargin) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { showWarning() diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index f5162a59522f0..4a73466841f69 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -62,7 +62,7 @@ object LocalKMeans { bestIndex } - def showWarning() { + def showWarning(): Unit = { System.err.println( """WARN: This is a naive implementation of KMeans Clustering and is given as an example! |Please use org.apache.spark.ml.clustering.KMeans @@ -70,7 +70,7 @@ object LocalKMeans { """.stripMargin) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { showWarning() diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index bde8ccd305960..4ca0ecdcfe6e0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -46,7 +46,7 @@ object LocalLR { Array.tabulate(N)(generatePoint) } - def showWarning() { + def showWarning(): Unit = { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! 
|Please use org.apache.spark.ml.classification.LogisticRegression @@ -54,7 +54,7 @@ object LocalLR { """.stripMargin) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { showWarning() diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index a93c15c85cfc1..7660ffd02ed9b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -21,7 +21,7 @@ package org.apache.spark.examples import scala.math.random object LocalPi { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { var count = 0 for (i <- 1 to 100000) { val x = random * 2 - 1 diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 03187aee044e4..e2120eaee6e5a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -41,7 +41,7 @@ object LogQuery { | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.split('\n').mkString ) - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val sparkConf = new SparkConf().setAppName("Log Query") val sc = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index e6f33b7adf5d1..4bea5cae775cb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession * Usage: MultiBroadcastTest [partitions] [numElem] */ object MultiBroadcastTest { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 2332a661f26a0..2bd7c3e954396 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] */ object SimpleSkewedGroupByTest { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("SimpleSkewedGroupByTest") diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 4d3c34041bc17..2e7abd62dcdc6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] */ object SkewedGroupByTest { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("GroupBy Test") diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 
d3e7b7a967de7..651f0224d4402 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -78,7 +78,7 @@ object SparkALS { new CholeskyDecomposition(XtX).getSolver.solve(Xty) } - def showWarning() { + def showWarning(): Unit = { System.err.println( """WARN: This is a naive implementation of ALS and is given as an example! |Please use org.apache.spark.ml.recommendation.ALS @@ -86,7 +86,7 @@ object SparkALS { """.stripMargin) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { var slices = 0 diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 23eaa879114a9..8c09ce614d931 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -49,7 +49,7 @@ object SparkHdfsLR { DataPoint(new DenseVector(x), y) } - def showWarning() { + def showWarning(): Unit = { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! |Please use org.apache.spark.ml.classification.LogisticRegression @@ -57,7 +57,7 @@ object SparkHdfsLR { """.stripMargin) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: SparkHdfsLR ") diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index b005cb6971c16..ec9b44ce6e3b7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -49,7 +49,7 @@ object SparkKMeans { bestIndex } - def showWarning() { + def showWarning(): Unit = { System.err.println( """WARN: This is a naive implementation of KMeans Clustering and is given as an example! |Please use org.apache.spark.ml.clustering.KMeans @@ -57,7 +57,7 @@ object SparkKMeans { """.stripMargin) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 3) { System.err.println("Usage: SparkKMeans ") diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 4b1497345af82..deb6668f7ecfc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -51,7 +51,7 @@ object SparkLR { Array.tabulate(N)(generatePoint) } - def showWarning() { + def showWarning(): Unit = { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! 
|Please use org.apache.spark.ml.classification.LogisticRegression @@ -59,7 +59,7 @@ object SparkLR { """.stripMargin) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { showWarning() diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 9299bad5d3290..3bd475c440d72 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.SparkSession */ object SparkPageRank { - def showWarning() { + def showWarning(): Unit = { System.err.println( """WARN: This is a naive implementation of PageRank and is given as an example! |Please use the PageRank implementation found in org.apache.spark.graphx.lib.PageRank @@ -47,7 +47,7 @@ object SparkPageRank { """.stripMargin) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 1) { System.err.println("Usage: SparkPageRank ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 828d98b5001d7..a8eec6a99cf4b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.SparkSession /** Computes an approximation to pi */ object SparkPi { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("Spark Pi") diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala index 64076f2deb706..99a12b9442365 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.SparkSession /** Usage: SparkRemoteFileTest [file] */ object SparkRemoteFileTest { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 1) { System.err.println("Usage: SparkRemoteFileTest ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index f5d42141f5dd2..7a6fa9a797ff9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -41,7 +41,7 @@ object SparkTC { edges.toSeq } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("SparkTC") diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index da3ffca1a6f2a..af18c0afbb223 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -23,7 +23,7 @@ package org.apache.spark.examples.graphx * http://snap.stanford.edu/data/soc-LiveJournal1.html. 
*/ object LiveJournalPageRank { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 1) { System.err.println( "Usage: LiveJournalPageRank \n" + diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 57b2edf992208..8bc9c0a86eab6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -47,7 +47,7 @@ object SynthBenchmark { * -degFile the local file to save the degree information (Default: Empty) * -seed seed to use for RNGs (Default: -1, picks seed randomly) */ - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val options = args.map { arg => arg.dropWhile(_ == '-').split('=') match { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index 8091838a2301e..354e65c2bae38 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -42,7 +42,7 @@ object ALSExample { } // $example off$ - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("ALSExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala index 5638e66b8792a..1a67a6e755ab4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.sql.SparkSession object ChiSqSelectorExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("ChiSqSelectorExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala index 91d861dd4380a..947ca5f5fb5e1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel} import org.apache.spark.sql.SparkSession object CountVectorizerExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("CountVectorizerExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index ee4469faab3a0..4377efd9e95fa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -41,7 +41,7 @@ object DataFrameExample { case class Params(input: String = "data/mllib/sample_libsvm_data.txt") extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("DataFrameExample") { diff --git 
a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 19f2d7751bc54..ef38163d7eb0d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -65,7 +65,7 @@ object DecisionTreeExample { checkpointDir: Option[String] = None, checkpointInterval: Int = 10) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("DecisionTreeExample") { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 2dc11b07d88ef..9b5dfed0cb31b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.{Dataset, Row, SparkSession} */ object DeveloperApiExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("DeveloperApiExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index 8f3ce4b315bd3..ca4235d53e636 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -63,7 +63,7 @@ object GBTExample { checkpointDir: Option[String] = None, checkpointInterval: Int = 10) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("GBTExample") { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala index 2940682c32801..b3642c0b45db6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.feature.{IndexToString, StringIndexer} import org.apache.spark.sql.SparkSession object IndexToStringExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("IndexToStringExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index 6903a1c298ced..370c6fd7c17fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -50,7 +50,7 @@ object LinearRegressionExample { tol: Double = 1E-6, fracTest: Double = 0.2) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("LinearRegressionExample") { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 
bd6cc8cff2348..b64ab4792add4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -55,7 +55,7 @@ object LogisticRegressionExample { tol: Double = 1E-6, fracTest: Double = 0.2) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("LogisticRegressionExample") { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index 4ad6c7c3ef202..86e70e8ab0189 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.SparkSession */ object OneVsRestExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName(s"OneVsRestExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index 0fe16fb6dfa9f..55823fe1832e5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.feature.QuantileDiscretizer import org.apache.spark.sql.SparkSession object QuantileDiscretizerExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("QuantileDiscretizerExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 3c127a46e1f10..6ba14bcd1822f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -64,7 +64,7 @@ object RandomForestExample { checkpointDir: Option[String] = None, checkpointInterval: Int = 10) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("RandomForestExample") { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala index bb4587b82cb37..bf6a4846b6e34 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.SQLTransformer import org.apache.spark.sql.SparkSession object SQLTransformerExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("SQLTransformerExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala index ec2df2ef876ba..6121c81cd1f5d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.SparkSession object TfIdfExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("TfIdfExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala index b4179ecc1e56d..05f2ee3288624 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala @@ -82,7 +82,7 @@ object UnaryTransformerExample { object MyTransformer extends DefaultParamsReadable[MyTransformer] // $example off$ - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder() .appName("UnaryTransformerExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala index 4bcc6ac6a01f5..8ff0e8c6a51c8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession object Word2VecExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("Word2Vec example") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index a07535bb5a38d..1a7839414b38e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset object AssociationRulesExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("AssociationRulesExample") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index e3cc1d9c83361..6fc3501fc57b5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -58,7 +58,7 @@ object BinaryClassification { regType: RegType = L2, regParam: Double = 0.01) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("BinaryClassification") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala index 53d0b8fc208ef..b7f0ba00f913e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala @@ -34,7 +34,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} */ object BisectingKMeansExample { - def main(args: Array[String]) { + def main(args: 
Array[String]): Unit = { val sparkConf = new SparkConf().setAppName("mllib.BisectingKMeansExample") val sc = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index 0b44c339ef139..cf9f7adbf6999 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -37,7 +37,7 @@ object Correlations { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index 681465d2176d4..9082f0b5a8b85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -45,7 +45,7 @@ object CosineSimilarity { case class Params(inputFile: String = null, threshold: Double = 0.1) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("CosineSimilarity") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b5d1b02f92524..1029ca04c348f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -67,7 +67,7 @@ object DecisionTreeRunner { checkpointDir: Option[String] = None, checkpointInterval: Int = 10) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("DecisionTreeRunner") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index b228827e5886f..0259df2799174 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -47,7 +47,7 @@ object DenseKMeans { numIterations: Int = 10, initializationMode: InitializationMode = Parallel) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("DenseKMeans") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index f724ee1030f04..a25ce826ee842 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -35,7 +35,7 @@ object FPGrowthExample { minSupport: Double = 0.3, numPartition: Int = -1) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("FPGrowthExample") { diff --git 
a/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala index b1b3a79d87ae1..103d212a80e78 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors object GaussianMixtureExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("GaussianMixtureExample") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 3f264933cd3cc..12e0c8df274b2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -50,7 +50,7 @@ object GradientBoostedTreesRunner { numIterations: Int = 10, fracTest: Double = 0.2) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("GradientBoostedTrees") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index 9b3c3266ee30a..8435209377553 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.RDD object HypothesisTestingExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("HypothesisTestingExample") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala index b0a6f1671a898..17ebd4159b8d7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors object KMeansExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("KMeansExample") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index cd77ecf990b3b..605ca68e627ec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -53,7 +53,7 @@ object LDAExample { checkpointDir: Option[String] = None, checkpointInterval: Int = 10) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("LDAExample") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala 
index d25962c5500ed..55a45b302b5a3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors object LatentDirichletAllocationExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("LatentDirichletAllocationExample") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala deleted file mode 100644 index 03222b13ad27d..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.mllib - -import org.apache.log4j.{Level, Logger} -import scopt.OptionParser - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.util.MLUtils - -/** - * An example app for linear regression. Run with - * {{{ - * bin/run-example org.apache.spark.examples.mllib.LinearRegression - * }}} - * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt`. - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
- */ -@deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0") -object LinearRegression { - - object RegType extends Enumeration { - type RegType = Value - val NONE, L1, L2 = Value - } - - import RegType._ - - case class Params( - input: String = null, - numIterations: Int = 100, - stepSize: Double = 1.0, - regType: RegType = L2, - regParam: Double = 0.01) extends AbstractParams[Params] - - def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("LinearRegression") { - head("LinearRegression: an example app for linear regression.") - opt[Int]("numIterations") - .text("number of iterations") - .action((x, c) => c.copy(numIterations = x)) - opt[Double]("stepSize") - .text(s"initial step size, default: ${defaultParams.stepSize}") - .action((x, c) => c.copy(stepSize = x)) - opt[String]("regType") - .text(s"regularization type (${RegType.values.mkString(",")}), " + - s"default: ${defaultParams.regType}") - .action((x, c) => c.copy(regType = RegType.withName(x))) - opt[Double]("regParam") - .text(s"regularization parameter, default: ${defaultParams.regParam}") - arg[String]("") - .required() - .text("input paths to labeled examples in LIBSVM format") - .action((x, c) => c.copy(input = x)) - note( - """ - |For example, the following command runs this app on a synthetic dataset: - | - | bin/spark-submit --class org.apache.spark.examples.mllib.LinearRegression \ - | examples/target/scala-*/spark-examples-*.jar \ - | data/mllib/sample_linear_regression_data.txt - """.stripMargin) - } - - parser.parse(args, defaultParams) match { - case Some(params) => run(params) - case _ => sys.exit(1) - } - } - - def run(params: Params): Unit = { - val conf = new SparkConf().setAppName(s"LinearRegression with $params") - val sc = new SparkContext(conf) - - Logger.getRootLogger.setLevel(Level.WARN) - - val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() - - val splits = examples.randomSplit(Array(0.8, 0.2)) - val training = splits(0).cache() - val test = splits(1).cache() - - val numTraining = training.count() - val numTest = test.count() - println(s"Training: $numTraining, test: $numTest.") - - examples.unpersist() - - val updater = params.regType match { - case NONE => new SimpleUpdater() - case L1 => new L1Updater() - case L2 => new SquaredL2Updater() - } - - val algorithm = new LinearRegressionWithSGD() - algorithm.optimizer - .setNumIterations(params.numIterations) - .setStepSize(params.stepSize) - .setUpdater(updater) - .setRegParam(params.regParam) - - val model = algorithm.run(training) - - val prediction = model.predict(test.map(_.features)) - val predictionAndLabel = prediction.zip(test.map(_.label)) - - val loss = predictionAndLabel.map { case (p, l) => - val err = p - l - err * err - }.reduce(_ + _) - val rmse = math.sqrt(loss / numTest) - - println(s"Test RMSE = $rmse.") - - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala deleted file mode 100644 index 449b725d1d173..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.mllib - -import org.apache.spark.{SparkConf, SparkContext} -// $example on$ -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -// $example off$ - -@deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0") -object LinearRegressionWithSGDExample { - - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("LinearRegressionWithSGDExample") - val sc = new SparkContext(conf) - - // $example on$ - // Load and parse the data - val data = sc.textFile("data/mllib/ridge-data/lpsa.data") - val parsedData = data.map { line => - val parts = line.split(',') - LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) - }.cache() - - // Building the model - val numIterations = 100 - val stepSize = 0.00000001 - val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize) - - // Evaluate model on training examples and compute training error - val valuesAndPreds = parsedData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) - } - val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() - println(s"training Mean Squared Error $MSE") - - // Save and load model - model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") - val sameModel = LinearRegressionModel.load(sc, "target/tmp/scalaLinearRegressionWithSGDModel") - // $example off$ - - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index fd810155d6a88..92c85c9271a5a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -48,7 +48,7 @@ object MovieLensALS { numProductBlocks: Int = -1, implicitPrefs: Boolean = false) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("MovieLensALS") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index f9e47e485e72f..b5c52f9a31224 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -38,7 +38,7 @@ object MultivariateSummarizer { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: 
Array[String]): Unit = { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala deleted file mode 100644 index eff2393cc3abe..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.mllib - -import org.apache.spark.SparkConf -import org.apache.spark.SparkContext -// $example on$ -import org.apache.spark.mllib.feature.PCA -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD} -// $example off$ - -@deprecated("Deprecated since LinearRegressionWithSGD is deprecated. Use ml.feature.PCA", "2.0.0") -object PCAExample { - - def main(args: Array[String]): Unit = { - - val conf = new SparkConf().setAppName("PCAExample") - val sc = new SparkContext(conf) - - // $example on$ - val data = sc.textFile("data/mllib/ridge-data/lpsa.data").map { line => - val parts = line.split(',') - LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) - }.cache() - - val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) - val training = splits(0).cache() - val test = splits(1) - - val pca = new PCA(training.first().features.size / 2).fit(data.map(_.features)) - val training_pca = training.map(p => p.copy(features = pca.transform(p.features))) - val test_pca = test.map(p => p.copy(features = pca.transform(p.features))) - - val numIterations = 100 - val model = LinearRegressionWithSGD.train(training, numIterations) - val model_pca = LinearRegressionWithSGD.train(training_pca, numIterations) - - val valuesAndPreds = test.map { point => - val score = model.predict(point.features) - (score, point.label) - } - - val valuesAndPreds_pca = test_pca.map { point => - val score = model_pca.predict(point.features) - (score, point.label) - } - - val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean() - val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean() - - println(s"Mean Squared Error = $MSE") - println(s"PCA Mean Squared Error = $MSE_pca") - // $example off$ - - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 65603252c4384..eaf1dacd0160a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -62,7 +62,7 @@ object PowerIterationClusteringExample { maxIterations: Int = 15 ) extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("PowerIterationClusteringExample") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index 8b789277774af..1b5d919a047e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -25,7 +25,7 @@ import org.apache.spark.mllib.fpm.PrefixSpan object PrefixSpanExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("PrefixSpanExample") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index 7ccbb5a0640cd..aee12a1b4751f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD */ object RandomRDDGeneration { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName(s"RandomRDDGeneration") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala index ea13ec05e2fad..2845028dd0814 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -25,7 +25,7 @@ import org.apache.spark.mllib.recommendation.{ALS, Rating} import org.apache.spark.sql.SparkSession object RankingMetricsExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder .appName("RankingMetricsExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala deleted file mode 100644 index 76cfb804e18f3..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -// scalastyle:off println - -package org.apache.spark.examples.mllib - -// $example on$ -import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD} -// $example off$ -import org.apache.spark.sql.SparkSession - -@deprecated("Use ml.regression.LinearRegression and the resulting model summary for metrics", - "2.0.0") -object RegressionMetricsExample { - def main(args: Array[String]): Unit = { - val spark = SparkSession - .builder - .appName("RegressionMetricsExample") - .getOrCreate() - // $example on$ - // Load the data - val data = spark - .read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") - .rdd.map(row => LabeledPoint(row.getDouble(0), row.get(1).asInstanceOf[Vector])) - .cache() - - // Build the model - val numIterations = 100 - val model = LinearRegressionWithSGD.train(data, numIterations) - - // Get predictions - val valuesAndPreds = data.map{ point => - val prediction = model.predict(point.features) - (prediction, point.label) - } - - // Instantiate metrics object - val metrics = new RegressionMetrics(valuesAndPreds) - - // Squared error - println(s"MSE = ${metrics.meanSquaredError}") - println(s"RMSE = ${metrics.rootMeanSquaredError}") - - // R-squared - println(s"R-squared = ${metrics.r2}") - - // Mean absolute error - println(s"MAE = ${metrics.meanAbsoluteError}") - - // Explained variance - println(s"Explained variance = ${metrics.explainedVariance}") - // $example off$ - - spark.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index ba3deae5d688f..fdde47d60c544 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -35,7 +35,7 @@ object SampledRDDs { case class Params(input: String = "data/mllib/sample_binary_classification_data.txt") extends AbstractParams[Params] - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("SampledRDDs") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index 694c3bb18b045..ba16e8f5ff347 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD object SimpleFPGrowth { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("SimpleFPGrowth") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index b76add2f9bc99..b501f4db2efbb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -40,7 +40,7 @@ object SparseNaiveBayes { numFeatures: Int = -1, lambda: Double = 1.0) extends AbstractParams[Params] - def main(args: Array[String]) { 
+ def main(args: Array[String]): Unit = { val defaultParams = Params() val parser = new OptionParser[Params]("SparseNaiveBayes") { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala index 7888af79f87f4..5186f599d9628 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala @@ -52,7 +52,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} */ object StreamingKMeansExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 5) { System.err.println( "Usage: StreamingKMeansExample " + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index a8b144a197229..4c72f444ff9ec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -46,7 +46,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} */ object StreamingLogisticRegression { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 4) { System.err.println( diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index ae4dee24c6474..f60b10a02274b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.Utils */ object StreamingTestExample { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 3) { // scalastyle:off println System.err.println( diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 071d341b81614..6b839f3f4ac1e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). */ object TallSkinnyPCA { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 1) { System.err.println("Usage: TallSkinnyPCA ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 8ae6de16d80e7..8874c2eda3d2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). 
*/ object TallSkinnySVD { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 1) { System.err.println("Usage: TallSkinnySVD ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index deaa9f252b9b0..4fd482d5b8bf7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession case class Record(key: Int, value: String) object RDDRelation { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { // $example on:init_session$ val spark = SparkSession .builder diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index c7b6a50f0ae7c..d4c05e5ad9944 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -24,7 +24,7 @@ object SQLDataSourceExample { case class Person(name: String, age: Long) - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val spark = SparkSession .builder() .appName("Spark SQL data sources example") diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala index 678cbc64aff1f..fde281087c267 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala @@ -34,7 +34,7 @@ object SparkSQLExample { case class Person(name: String, age: Long) // $example off:create_ds$ - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { // $example on:init_session$ val spark = SparkSession .builder() diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index a832276602b88..3be8a3862f39c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -28,7 +28,7 @@ object SparkHiveExample { case class Record(key: Int, value: String) // $example off:spark_hive$ - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { // When working with Hive, one must instantiate `SparkSession` with Hive support, including // connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined // functions. Users who do not have an existing Hive deployment can still enable Hive support. 
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala index de477c5ce8161..6dbc70bd141f3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.SparkSession * localhost 9999` */ object StructuredNetworkWordCount { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: StructuredNetworkWordCount ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala index b4dad21dd75b0..4ba2c6bc68918 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.functions._ */ object StructuredNetworkWordCountWindowed { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 3) { System.err.println("Usage: StructuredNetworkWordCountWindowed " + " []") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index fc3f8fa53c7ae..0f47deaf1021b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -37,7 +37,7 @@ import org.apache.spark.streaming.receiver.Receiver * `$ bin/run-example org.apache.spark.examples.streaming.CustomReceiver localhost 9999` */ object CustomReceiver { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: CustomReceiver ") System.exit(1) @@ -64,20 +64,20 @@ object CustomReceiver { class CustomReceiver(host: String, port: Int) extends Receiver[String](StorageLevel.MEMORY_AND_DISK_2) { - def onStart() { + def onStart(): Unit = { // Start the thread that receives data over a connection new Thread("Socket Receiver") { - override def run() { receive() } + override def run(): Unit = { receive() } }.start() } - def onStop() { + def onStop(): Unit = { // There is nothing much to do as the thread calling receive() // is designed to stop by itself isStopped() returns false } /** Create a socket connection and receive data until receiver is stopped */ - private def receive() { + private def receive(): Unit = { var socket: Socket = null var userInput: String = null try { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 3024b59480099..6fdb37194ea7d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -37,7 +37,7 @@ import org.apache.spark.streaming.kafka010._ * consumer-group topic1,topic2 */ object DirectKafkaWordCount { - def main(args: 
Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 3) { System.err.println(s""" |Usage: DirectKafkaWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKerberizedKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKerberizedKafkaWordCount.scala index b68a59873a8fe..6a35ce9b2a293 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKerberizedKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKerberizedKafkaWordCount.scala @@ -76,7 +76,7 @@ import org.apache.spark.streaming.kafka010._ * using SASL_SSL in production. */ object DirectKerberizedKafkaWordCount { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 3) { System.err.println(s""" |Usage: DirectKerberizedKafkaWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index 1f282d437dc38..19dc7a3cce0ac 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -33,7 +33,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} * Then create a text file in `localdir` and the words in the file will get counted. */ object HdfsWordCount { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 1) { System.err.println("Usage: HdfsWordCount ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 15b57fccb4076..26bb51dde3a1d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -34,7 +34,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} * `$ bin/run-example org.apache.spark.examples.streaming.NetworkWordCount localhost 9999` */ object NetworkWordCount { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: NetworkWordCount ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala index 19bacd449787b..09eeaf9fa4496 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala @@ -25,7 +25,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} object QueueStream { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { StreamingExamples.setStreamingLogLevels() val sparkConf = new SparkConf().setAppName("QueueStream") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 437ccf0898d7c..a20abd6e9d12e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -37,7 +37,7 @@ import org.apache.spark.util.IntParam * is the Spark Streaming batch duration in milliseconds. 
*/ object RawNetworkGrep { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 4) { System.err.println("Usage: RawNetworkGrep ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index f018f3a26d2e9..243c22e71275c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -139,7 +139,7 @@ object RecoverableNetworkWordCount { ssc } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 4) { System.err.println(s"Your arguments were ${args.mkString("[", ", ", "]")}") System.err.println( diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 787bbec73b28f..778be7baaeeac 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext, Time} */ object SqlNetworkWordCount { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: NetworkWordCount ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 2811e67009fb0..46f01edf7deec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -35,7 +35,7 @@ import org.apache.spark.streaming._ * org.apache.spark.examples.streaming.StatefulNetworkWordCount localhost 9999` */ object StatefulNetworkWordCount { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: StatefulNetworkWordCount ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala index b00f32fb25243..073f9728c68af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala @@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging object StreamingExamples extends Logging { /** Set reasonable logging levels for streaming if the user has not configured log4j. 
*/ - def setStreamingLogLevels() { + def setStreamingLogLevels(): Unit = { val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements if (!log4jInitialized) { // We first log something to initialize Spark's default logging, then we override the diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 2108bc63edea2..7234f30e7d267 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -81,7 +81,7 @@ object PageViewGenerator { new PageView(page, status, zipCode, id).toString() } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 2) { System.err.println("Usage: PageViewGenerator ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index b8e7c7e9e9152..b51bfacabf4aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -35,7 +35,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} */ // scalastyle:on object PageViewStream { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 3) { System.err.println("Usage: PageViewStream ") System.err.println(" must be one of pageCounts, slidingPageCounts," + diff --git a/external/avro/benchmarks/AvroReadBenchmark-jdk11-results.txt b/external/avro/benchmarks/AvroReadBenchmark-jdk11-results.txt new file mode 100644 index 0000000000000..94137a691e4aa --- /dev/null +++ b/external/avro/benchmarks/AvroReadBenchmark-jdk11-results.txt @@ -0,0 +1,122 @@ +================================================================================================ +SQL Single Numeric Column Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 2995 3081 121 5.3 190.4 1.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 2865 2881 23 5.5 182.2 1.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 2919 2936 23 5.4 185.6 1.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single BIGINT Column Scan: 
Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 3148 3262 161 5.0 200.1 1.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 2651 2721 99 5.9 168.5 1.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 2782 2854 103 5.7 176.9 1.0X + + +================================================================================================ +Int and String Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of columns 4531 4583 73 2.3 432.1 1.0X + + +================================================================================================ +Partitioned Table Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Data column 3084 3105 30 5.1 196.1 1.0X +Partition column 3143 3164 30 5.0 199.8 1.0X +Both columns 3272 3339 94 4.8 208.1 0.9X + + +================================================================================================ +Repeated String Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of string length 3249 3318 98 3.2 309.8 1.0X + + +================================================================================================ +String with Nulls Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of string length 5308 5335 38 2.0 506.2 1.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of string length 4405 4429 33 2.4 420.1 1.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of string length 3256 3309 75 3.2 310.5 1.0X + + +================================================================================================ +Single Column Scan From Wide Columns +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of single column 5230 5290 85 0.2 4987.4 1.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of single column 10206 10329 174 0.1 9733.1 1.0X + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of single column 15333 15365 46 0.1 14622.3 1.0X + + diff --git a/external/avro/benchmarks/AvroReadBenchmark-results.txt b/external/avro/benchmarks/AvroReadBenchmark-results.txt index 7900fea453b10..7b008a312c320 100644 --- a/external/avro/benchmarks/AvroReadBenchmark-results.txt +++ b/external/avro/benchmarks/AvroReadBenchmark-results.txt @@ -2,121 +2,121 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum 2774 / 2815 5.7 176.4 1.0X +SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 3067 3132 91 5.1 195.0 1.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------- -Sum 2761 / 2777 5.7 175.5 1.0X +SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 2927 2929 3 5.4 186.1 1.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum 2783 / 2870 5.7 176.9 1.0X +SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 2928 2990 87 5.4 186.2 1.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum 3256 / 3266 4.8 207.0 1.0X +SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 3374 3447 104 4.7 214.5 1.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum 2841 / 2867 5.5 180.6 1.0X +SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 2896 2901 7 5.4 184.1 1.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum 2981 / 2996 5.3 189.5 1.0X +SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum 3004 3006 3 5.2 191.0 1.0X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------- -Sum of columns 4781 / 4783 2.2 456.0 1.0X +Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of columns 4814 4830 22 2.2 459.1 1.0X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Data column 3372 / 3386 4.7 214.4 1.0X -Partition column 3035 / 3064 5.2 193.0 1.1X -Both columns 3445 / 3461 4.6 219.1 1.0X +Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Data column 3361 3362 1 4.7 213.7 1.0X +Partition column 2999 3013 20 5.2 190.7 1.1X +Both columns 3613 3615 2 4.4 229.7 0.9X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum of string length 3395 / 3401 3.1 323.8 1.0X +Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of string length 3415 3416 1 3.1 325.7 1.0X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum of string length 5580 / 5624 1.9 532.2 1.0X +String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of string length 5535 5536 2 1.9 527.8 1.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------- -Sum of string length 4622 / 4623 2.3 440.8 1.0X +String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of string length 4567 4575 11 2.3 435.6 1.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum of string length 3238 / 3241 3.2 308.8 1.0X +String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of string length 3248 3268 29 3.2 309.7 1.0X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum of single column 5472 / 5484 0.2 5218.8 1.0X +Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of single column 5486 5497 15 0.2 5232.0 1.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum of single column 10680 / 10701 0.1 10185.1 1.0X +Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of single column 10682 10746 90 0.1 10186.8 1.0X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Sum of single column 16143 / 16238 0.1 15394.9 1.0X +Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Sum of single column 16177 16177 0 0.1 15427.7 1.0X diff --git a/external/avro/benchmarks/AvroWriteBenchmark-jdk11-results.txt 
b/external/avro/benchmarks/AvroWriteBenchmark-jdk11-results.txt new file mode 100644 index 0000000000000..2cf1835013821 --- /dev/null +++ b/external/avro/benchmarks/AvroWriteBenchmark-jdk11-results.txt @@ -0,0 +1,10 @@ +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Avro writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Output Single Int Column 3026 3142 164 5.2 192.4 1.0X +Output Single Double Column 3157 3260 145 5.0 200.7 1.0X +Output Int and String Column 6123 6190 94 2.6 389.3 0.5X +Output Partitions 5197 5733 758 3.0 330.4 0.6X +Output Buckets 7074 7285 298 2.2 449.7 0.4X + diff --git a/external/avro/benchmarks/AvroWriteBenchmark-results.txt b/external/avro/benchmarks/AvroWriteBenchmark-results.txt index fb2a77333eec5..20f6ae9099a4d 100644 --- a/external/avro/benchmarks/AvroWriteBenchmark-results.txt +++ b/external/avro/benchmarks/AvroWriteBenchmark-results.txt @@ -1,10 +1,10 @@ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Avro writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Output Single Int Column 3213 / 3373 4.9 204.3 1.0X -Output Single Double Column 3313 / 3345 4.7 210.7 1.0X -Output Int and String Column 7303 / 7316 2.2 464.3 0.4X -Output Partitions 5309 / 5691 3.0 337.5 0.6X -Output Buckets 7031 / 7557 2.2 447.0 0.5X +Avro writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Output Single Int Column 3080 3137 82 5.1 195.8 1.0X +Output Single Double Column 3595 3595 0 4.4 228.6 0.9X +Output Int and String Column 7491 7504 18 2.1 476.3 0.4X +Output Partitions 5518 5663 205 2.9 350.8 0.6X +Output Buckets 7467 7581 161 2.1 474.7 0.4X diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala index 3171f1e08b4fc..c6f52d676422c 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.v2.avro import org.apache.spark.sql.avro.AvroFileFormat +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 -import org.apache.spark.sql.sources.v2.Table import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index 243af7da47003..0397d15aed924 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -31,10 +31,10 @@ import org.apache.spark.broadcast.Broadcast import 
org.apache.spark.internal.Logging import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.datasources.v2.{EmptyPartitionReader, FilePartitionReaderFactory, PartitionReaderWithPartitionValues} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.PartitionReader import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index 6ec351080a118..e1268ac2ce581 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan -import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala index 815da2bd92d44..e36c71ef4b1f7 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.v2.avro import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.v2.reader.Scan import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala index a781624aa61aa..765e5727d944a 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala @@ -22,9 +22,9 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession import org.apache.spark.sql.avro.AvroUtils +import org.apache.spark.sql.connector.write.WriteBuilder import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index cf88981b1efbd..dc60cfe41ca7a 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ 
b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1036,7 +1036,7 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession { (TimestampType, LONG), (DecimalType(4, 2), BYTES) ) - def assertException(f: () => AvroSerializer) { + def assertException(f: () => AvroSerializer): Unit = { val message = intercept[org.apache.spark.sql.avro.IncompatibleSchemaException] { f() }.getMessage diff --git a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala index f2f7d650066fb..a16126ae24246 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala @@ -22,7 +22,6 @@ import scala.util.Random import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.types._ /** @@ -36,7 +35,7 @@ import org.apache.spark.sql.types._ * Results will be written to "benchmarks/AvroReadBenchmark-results.txt". * }}} */ -object AvroReadBenchmark extends SqlBasedBenchmark with SQLHelper { +object AvroReadBenchmark extends SqlBasedBenchmark { def withTempTable(tableNames: String*)(f: => Unit): Unit = { try f finally tableNames.foreach(spark.catalog.dropTempView) } diff --git a/external/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile index c1fd630d0b665..5bec5d3f16548 100644 --- a/external/docker/spark-test/base/Dockerfile +++ b/external/docker/spark-test/base/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update && \ apt-get install -y less openjdk-8-jre-headless iproute2 vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* -ENV SCALA_VERSION 2.12.8 +ENV SCALA_VERSION 2.12.10 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 0735f0a7b937f..693820da6af6b 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -46,6 +46,13 @@ ${project.version} provided + + org.apache.spark + spark-token-provider-kafka-0-10_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-core_${scala.binary.version} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala index 868edb5dcdc0c..6dd5af2389a81 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala @@ -68,7 +68,7 @@ private object JsonUtils { partOffsets.map { case (part, offset) => new TopicPartition(topic, part) -> offset } - }.toMap + } } catch { case NonFatal(x) => throw new IllegalArgumentException( @@ -76,12 +76,27 @@ private object JsonUtils { } } + def partitionTimestamps(str: String): Map[TopicPartition, Long] = { + try { + Serialization.read[Map[String, Map[Int, Long]]](str).flatMap { case (topic, partTimestamps) => + partTimestamps.map { case (part, timestamp) => + new TopicPartition(topic, part) -> timestamp + } + } + } catch { + case NonFatal(x) => + throw new IllegalArgumentException( + s"""Expected e.g. 
{"topicA": {"0": 123456789, "1": 123456789}, + |"topicB": {"0": 123456789, "1": 123456789}}, got $str""".stripMargin) + } + } + /** * Write per-TopicPartition offsets as json string */ def partitionOffsets(partitionOffsets: Map[TopicPartition, Long]): String = { val result = new HashMap[String, HashMap[Int, Long]]() - implicit val ordering = new Ordering[TopicPartition] { + implicit val order = new Ordering[TopicPartition] { override def compare(x: TopicPartition, y: TopicPartition): Int = { Ordering.Tuple2[String, Int].compare((x.topic, x.partition), (y.topic, y.partition)) } @@ -95,4 +110,9 @@ private object JsonUtils { } Serialization.write(result) } + + def partitionTimestamps(topicTimestamps: Map[TopicPartition, Long]): String = { + // For now it's same as partitionOffsets + partitionOffsets(topicTimestamps) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala index 700414167f3ef..3006770f306c0 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala @@ -23,8 +23,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory} - +import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory} private[kafka010] class KafkaBatch( strategy: ConsumerStrategy, @@ -32,7 +31,8 @@ private[kafka010] class KafkaBatch( specifiedKafkaParams: Map[String, String], failOnDataLoss: Boolean, startingOffsets: KafkaOffsetRangeLimit, - endingOffsets: KafkaOffsetRangeLimit) + endingOffsets: KafkaOffsetRangeLimit, + includeHeaders: Boolean) extends Batch with Logging { assert(startingOffsets != LatestOffsetRangeLimit, "Starting offset not allowed to be set to latest offsets.") @@ -59,8 +59,8 @@ private[kafka010] class KafkaBatch( // Leverage the KafkaReader to obtain the relevant partition offsets val (fromPartitionOffsets, untilPartitionOffsets) = { try { - (kafkaOffsetReader.fetchPartitionOffsets(startingOffsets), - kafkaOffsetReader.fetchPartitionOffsets(endingOffsets)) + (kafkaOffsetReader.fetchPartitionOffsets(startingOffsets, isStartingOffsets = true), + kafkaOffsetReader.fetchPartitionOffsets(endingOffsets, isStartingOffsets = false)) } finally { kafkaOffsetReader.close() } @@ -91,7 +91,7 @@ private[kafka010] class KafkaBatch( KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) offsetRanges.map { range => new KafkaBatchInputPartition( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders) }.toArray } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala index 53b0b3c46854e..645b68b0c407a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala @@ -22,21 +22,21 @@ import java.{util => ju} import org.apache.spark.internal.Logging import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.sources.v2.reader._ - +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} /** A [[InputPartition]] for reading Kafka data in a batch based streaming query. */ private[kafka010] case class KafkaBatchInputPartition( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends InputPartition + failOnDataLoss: Boolean, + includeHeaders: Boolean) extends InputPartition private[kafka010] object KafkaBatchReaderFactory extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val p = partition.asInstanceOf[KafkaBatchInputPartition] KafkaBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs, - p.failOnDataLoss) + p.failOnDataLoss, p.includeHeaders) } } @@ -45,12 +45,14 @@ private case class KafkaBatchPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends PartitionReader[InternalRow] with Logging { + failOnDataLoss: Boolean, + includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams) private val rangeToRead = resolveRange(offsetRange) - private val converter = new KafkaRecordToUnsafeRowConverter + private val unsafeRowProjector = new KafkaRecordToRowConverter() + .toUnsafeRowProjector(includeHeaders) private var nextOffset = rangeToRead.fromOffset private var nextRow: UnsafeRow = _ @@ -59,7 +61,7 @@ private case class KafkaBatchPartitionReader( if (nextOffset < rangeToRead.untilOffset) { val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) if (record != null) { - nextRow = converter.toUnsafeRow(record) + nextRow = unsafeRowProjector(record) nextOffset = record.offset + 1 true } else { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala index 47ec07ae128d2..8e29e38b2a644 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType /** diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala index a9c1181a01c51..0603ae39ba622 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala @@ -27,9 +27,9 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset} +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -56,6 +56,7 @@ class KafkaContinuousStream( private[kafka010] val pollTimeoutMs = options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512) + private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) // Initialized when creating reader factories. If this diverges from the partitions at the latest // offsets, we need to reconfigure. @@ -68,6 +69,8 @@ class KafkaContinuousStream( case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets(None)) case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + case SpecificTimestampRangeLimit(p) => offsetReader.fetchSpecificTimestampBasedOffsets(p, + failsOnNoMatchingOffset = true) } logInfo(s"Initial offsets: $offsets") offsets @@ -88,7 +91,7 @@ class KafkaContinuousStream( if (deletedPartitions.nonEmpty) { val message = if ( offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) { - s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}" + s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}" } else { s"$deletedPartitions are gone. Some data may have been missed." } @@ -102,7 +105,7 @@ class KafkaContinuousStream( startOffsets.toSeq.map { case (topicPartition, start) => KafkaContinuousInputPartition( - topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders) }.toArray } @@ -153,19 +156,22 @@ class KafkaContinuousStream( * @param pollTimeoutMs The timeout for Kafka consumer polling. * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. + * @param includeHeaders Flag indicating whether to include Kafka records' headers. 
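For context on the header support being threaded through these readers, a minimal usage sketch follows. It assumes an active SparkSession named spark, and that the INCLUDE_HEADERS key referenced above maps to the user-facing option name "includeHeaders" (the constant's value is not shown in this hunk); broker and topic names are placeholders.

    // Minimal sketch: enable Kafka record headers on a streaming read.
    val df = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "broker1:9092")   // placeholder broker
      .option("subscribe", "topicA")                        // placeholder topic
      .option("includeHeaders", "true")                     // assumed option name behind INCLUDE_HEADERS
      .load()

    // With headers enabled, each row gains a column matching KafkaRecordToRowConverter.headersType:
    //   headers: array<struct<key: string, value: binary>>
    val withHeaders = df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
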
*/ case class KafkaContinuousInputPartition( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends InputPartition + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean, + includeHeaders: Boolean) extends InputPartition object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory { override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { val p = partition.asInstanceOf[KafkaContinuousInputPartition] new KafkaContinuousPartitionReader( - p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss) + p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, + p.failOnDataLoss, p.includeHeaders) } } @@ -184,9 +190,11 @@ class KafkaContinuousPartitionReader( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] { + failOnDataLoss: Boolean, + includeHeaders: Boolean) extends ContinuousPartitionReader[InternalRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams) - private val converter = new KafkaRecordToUnsafeRowConverter + private val unsafeRowProjector = new KafkaRecordToRowConverter() + .toUnsafeRowProjector(includeHeaders) private var nextKafkaOffset = startOffset private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ @@ -225,7 +233,7 @@ class KafkaContinuousPartitionReader( } override def get(): UnsafeRow = { - converter.toUnsafeRow(currentRecord) + unsafeRowProjector(currentRecord) } override def getOffset(): KafkaSourcePartitionOffset = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 87036beb9a252..ca82c908f441b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -23,12 +23,13 @@ import java.util.concurrent.TimeoutException import scala.collection.JavaConverters._ +import org.apache.kafka.clients.CommonClientConfigs import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer, OffsetOutOfRangeException} import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.kafka010.KafkaConfigUpdater +import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaTokenClusterConf, KafkaTokenUtil} import org.apache.spark.sql.kafka010.KafkaDataConsumer.{AvailableOffsetRange, UNKNOWN_OFFSET} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.{ShutdownHookManager, UninterruptibleThread} @@ -46,6 +47,13 @@ private[kafka010] class InternalKafkaConsumer( val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + private[kafka010] val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig( + SparkEnv.get.conf, kafkaParams.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + .asInstanceOf[String]) + + // Kafka consumer is not able to give back the params instantiated with so we need to store it. + // It must be updated whenever a new consumer is created. 
+ private[kafka010] var kafkaParamsWithSecurity: ju.Map[String, Object] = _ private val consumer = createConsumer() /** @@ -106,10 +114,10 @@ private[kafka010] class InternalKafkaConsumer( /** Create a KafkaConsumer to fetch records for `topicPartition` */ private def createConsumer(): KafkaConsumer[Array[Byte], Array[Byte]] = { - val updatedKafkaParams = KafkaConfigUpdater("executor", kafkaParams.asScala.toMap) - .setAuthenticationConfigIfNeeded() + kafkaParamsWithSecurity = KafkaConfigUpdater("executor", kafkaParams.asScala.toMap) + .setAuthenticationConfigIfNeeded(clusterConfig) .build() - val c = new KafkaConsumer[Array[Byte], Array[Byte]](updatedKafkaParams) + val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParamsWithSecurity) val tps = new ju.ArrayList[TopicPartition]() tps.add(topicPartition) c.assign(tps) @@ -516,13 +524,25 @@ private[kafka010] class KafkaDataConsumer( fetchedData.withNewPoll(records.listIterator, offsetAfterPoll) } - private def getOrRetrieveConsumer(): InternalKafkaConsumer = _consumer match { - case None => - _consumer = Option(consumerPool.borrowObject(cacheKey, kafkaParams)) - require(_consumer.isDefined, "borrowing consumer from pool must always succeed.") - _consumer.get + private[kafka010] def getOrRetrieveConsumer(): InternalKafkaConsumer = { + if (!_consumer.isDefined) { + retrieveConsumer() + } + require(_consumer.isDefined, "Consumer must be defined") + if (!KafkaTokenUtil.isConnectorUsingCurrentToken(_consumer.get.kafkaParamsWithSecurity, + _consumer.get.clusterConfig)) { + logDebug("Cached consumer uses an old delegation token, invalidating.") + releaseConsumer() + consumerPool.invalidateKey(cacheKey) + fetchedDataPool.invalidate(cacheKey) + retrieveConsumer() + } + _consumer.get + } - case Some(consumer) => consumer + private def retrieveConsumer(): Unit = { + _consumer = Option(consumerPool.borrowObject(cacheKey, kafkaParams)) + require(_consumer.isDefined, "borrowing consumer from pool must always succeed.") } private def getOrRetrieveFetchedData(offset: Long): FetchedData = _fetchedData match { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala index 884773452b2a5..3f8d3d2da5797 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala @@ -21,7 +21,7 @@ import java.{util => ju} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} /** * Dummy commit message. 
The DataSourceV2 framework requires a commit message implementation but we diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 9cd16c8e16249..01f6ba4445162 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -26,10 +26,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} +import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.UninterruptibleThread @@ -64,6 +64,8 @@ private[kafka010] class KafkaMicroBatchStream( private[kafka010] val maxOffsetsPerTrigger = Option(options.get( KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong) + private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) + private val rangeCalculator = KafkaOffsetRangeCalculator(options) private var endPartitionOffsets: KafkaSourceOffset = _ @@ -112,7 +114,7 @@ private[kafka010] class KafkaMicroBatchStream( if (deletedPartitions.nonEmpty) { val message = if (kafkaOffsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) { - s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}" + s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}" } else { s"$deletedPartitions are gone. Some data may have been missed." 
} @@ -146,7 +148,8 @@ private[kafka010] class KafkaMicroBatchStream( // Generate factories based on the offset ranges offsetRanges.map { range => - KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) + KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, + failOnDataLoss, includeHeaders) }.toArray } @@ -189,6 +192,8 @@ private[kafka010] class KafkaMicroBatchStream( KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets(None)) case SpecificOffsetRangeLimit(p) => kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss) + case SpecificTimestampRangeLimit(p) => + kafkaOffsetReader.fetchSpecificTimestampBasedOffsets(p, failsOnNoMatchingOffset = true) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala index 80a026f4f5d73..d64b5d4f7e9e8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala @@ -42,6 +42,13 @@ private[kafka010] case object LatestOffsetRangeLimit extends KafkaOffsetRangeLim private[kafka010] case class SpecificOffsetRangeLimit( partitionOffsets: Map[TopicPartition, Long]) extends KafkaOffsetRangeLimit +/** + * Represents the desire to bind to earliest offset which timestamp for the offset is equal or + * greater than specific timestamp. + */ +private[kafka010] case class SpecificTimestampRangeLimit( + topicTimestamps: Map[TopicPartition, Long]) extends KafkaOffsetRangeLimit + private[kafka010] object KafkaOffsetRangeLimit { /** * Used to denote offset range limits that are resolved via Kafka diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index f3effd5300a79..0179f4dd822f1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -26,12 +26,11 @@ import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.util.control.NonFatal -import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetAndTimestamp} import org.apache.kafka.common.TopicPartition import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.types._ import org.apache.spark.util.{ThreadUtils, UninterruptibleThread} /** @@ -127,12 +126,14 @@ private[kafka010] class KafkaOffsetReader( * Fetch the partition offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]] and [[KafkaOffsetRangeLimit]]. 
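Since SpecificTimestampRangeLimit is new here, a short sketch of how it is expected to be reached from user code, assuming the "startingOffsetsByTimestamp" option name that appears in the assertion message below and an active SparkSession named spark; the JSON shape is the one JsonUtils.partitionTimestamps parses (topic -> partition -> timestamp in milliseconds), and broker, topic and timestamp values are placeholders.

    // Minimal sketch: start a stream from per-partition timestamps rather than offsets.
    val startingByTimestamp = """{"topicA": {"0": 1565000000000, "1": 1565000000000}}"""
    val df = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "broker1:9092")          // placeholder broker
      .option("subscribe", "topicA")                               // placeholder topic
      .option("startingOffsetsByTimestamp", startingByTimestamp)   // assumed option name
      .load()
    // Internally this should become SpecificTimestampRangeLimit(Map(TopicPartition -> timestamp)) and
    // be resolved through fetchSpecificTimestampBasedOffsets with failsOnNoMatchingOffset = true.
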
*/ - def fetchPartitionOffsets(offsetRangeLimit: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = { + def fetchPartitionOffsets( + offsetRangeLimit: KafkaOffsetRangeLimit, + isStartingOffsets: Boolean): Map[TopicPartition, Long] = { def validateTopicPartitions(partitions: Set[TopicPartition], partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { assert(partitions == partitionOffsets.keySet, "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" + - "Use -1 for latest, -2 for earliest, if you don't care.\n" + + "Use -1 for latest, -2 for earliest.\n" + s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions}") logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets") partitionOffsets @@ -148,6 +149,9 @@ private[kafka010] class KafkaOffsetReader( }.toMap case SpecificOffsetRangeLimit(partitionOffsets) => validateTopicPartitions(partitions, partitionOffsets) + case SpecificTimestampRangeLimit(partitionTimestamps) => + fetchSpecificTimestampBasedOffsets(partitionTimestamps, + failsOnNoMatchingOffset = isStartingOffsets).partitionToOffsets } } @@ -162,23 +166,83 @@ private[kafka010] class KafkaOffsetReader( def fetchSpecificOffsets( partitionOffsets: Map[TopicPartition, Long], reportDataLoss: String => Unit): KafkaSourceOffset = { - val fetched = runUninterruptibly { - withRetriesWithoutInterrupt { - // Poll to get the latest assigned partitions - consumer.poll(0) - val partitions = consumer.assignment() + val fnAssertParametersWithPartitions: ju.Set[TopicPartition] => Unit = { partitions => + assert(partitions.asScala == partitionOffsets.keySet, + "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" + + "Use -1 for latest, -2 for earliest, if you don't care.\n" + + s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}") + logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets") + } - // Call `position` to wait until the potential offset request triggered by `poll(0)` is - // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by - // `poll(0)` may reset offsets that should have been set by another request. - partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {}) + val fnRetrievePartitionOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long] = { _ => + partitionOffsets + } - consumer.pause(partitions) - assert(partitions.asScala == partitionOffsets.keySet, - "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" + - "Use -1 for latest, -2 for earliest, if you don't care.\n" + - s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}") - logDebug(s"Partitions assigned to consumer: $partitions. 
Seeking to $partitionOffsets") + val fnAssertFetchedOffsets: Map[TopicPartition, Long] => Unit = { fetched => + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + } + + fetchSpecificOffsets0(fnAssertParametersWithPartitions, fnRetrievePartitionOffsets, + fnAssertFetchedOffsets) + } + + def fetchSpecificTimestampBasedOffsets( + partitionTimestamps: Map[TopicPartition, Long], + failsOnNoMatchingOffset: Boolean): KafkaSourceOffset = { + val fnAssertParametersWithPartitions: ju.Set[TopicPartition] => Unit = { partitions => + assert(partitions.asScala == partitionTimestamps.keySet, + "If starting/endingOffsetsByTimestamp contains specific offsets, you must specify all " + + s"topics. Specified: ${partitionTimestamps.keySet} Assigned: ${partitions.asScala}") + logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionTimestamps") + } + + val fnRetrievePartitionOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long] = { _ => { + val converted = partitionTimestamps.map { case (tp, timestamp) => + tp -> java.lang.Long.valueOf(timestamp) + }.asJava + + val offsetForTime: ju.Map[TopicPartition, OffsetAndTimestamp] = + consumer.offsetsForTimes(converted) + + offsetForTime.asScala.map { case (tp, offsetAndTimestamp) => + if (failsOnNoMatchingOffset) { + assert(offsetAndTimestamp != null, "No offset matched from request of " + + s"topic-partition $tp and timestamp ${partitionTimestamps(tp)}.") + } + + if (offsetAndTimestamp == null) { + tp -> KafkaOffsetRangeLimit.LATEST + } else { + tp -> offsetAndTimestamp.offset() + } + }.toMap + } + } + + val fnAssertFetchedOffsets: Map[TopicPartition, Long] => Unit = { _ => } + + fetchSpecificOffsets0(fnAssertParametersWithPartitions, fnRetrievePartitionOffsets, + fnAssertFetchedOffsets) + } + + private def fetchSpecificOffsets0( + fnAssertParametersWithPartitions: ju.Set[TopicPartition] => Unit, + fnRetrievePartitionOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long], + fnAssertFetchedOffsets: Map[TopicPartition, Long] => Unit): KafkaSourceOffset = { + val fetched = partitionsAssignedToConsumer { + partitions => { + fnAssertParametersWithPartitions(partitions) + + val partitionOffsets = fnRetrievePartitionOffsets(partitions) partitionOffsets.foreach { case (tp, KafkaOffsetRangeLimit.LATEST) => @@ -187,22 +251,15 @@ private[kafka010] class KafkaOffsetReader( consumer.seekToBeginning(ju.Arrays.asList(tp)) case (tp, off) => consumer.seek(tp, off) } + partitionOffsets.map { case (tp, _) => tp -> consumer.position(tp) } } } - partitionOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (fetched(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } + fnAssertFetchedOffsets(fetched) + KafkaSourceOffset(fetched) } @@ -210,20 +267,15 @@ private[kafka010] class KafkaOffsetReader( * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. 
*/ - def fetchEarliestOffsets(): Map[TopicPartition, Long] = runUninterruptibly { - withRetriesWithoutInterrupt { - // Poll to get the latest assigned partitions - consumer.poll(0) - val partitions = consumer.assignment() - consumer.pause(partitions) - logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning") + def fetchEarliestOffsets(): Map[TopicPartition, Long] = partitionsAssignedToConsumer( + partitions => { + logDebug("Seeking to the beginning") consumer.seekToBeginning(partitions) val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap logDebug(s"Got earliest offsets for partition : $partitionOffsets") partitionOffsets - } - } + }, fetchingEarliestOffset = true) /** * Fetch the latest offsets for the topic partitions that are indicated @@ -240,19 +292,9 @@ private[kafka010] class KafkaOffsetReader( * distinguish this with KAFKA-7703, so we just return whatever we get from Kafka after retrying. */ def fetchLatestOffsets( - knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap = runUninterruptibly { - withRetriesWithoutInterrupt { - // Poll to get the latest assigned partitions - consumer.poll(0) - val partitions = consumer.assignment() - - // Call `position` to wait until the potential offset request triggered by `poll(0)` is - // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by - // `poll(0)` may reset offsets that should have been set by another request. - partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {}) - - consumer.pause(partitions) - logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.") + knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap = + partitionsAssignedToConsumer { partitions => { + logDebug("Seeking to the end.") if (knownOffsets.isEmpty) { consumer.seekToEnd(partitions) @@ -316,25 +358,40 @@ private[kafka010] class KafkaOffsetReader( if (newPartitions.isEmpty) { Map.empty[TopicPartition, Long] } else { - runUninterruptibly { - withRetriesWithoutInterrupt { - // Poll to get the latest assigned partitions - consumer.poll(0) - val partitions = consumer.assignment() - consumer.pause(partitions) - logDebug(s"\tPartitions assigned to consumer: $partitions") - - // Get the earliest offset of each partition - consumer.seekToBeginning(partitions) - val partitionOffsets = newPartitions.filter { p => - // When deleting topics happen at the same time, some partitions may not be in - // `partitions`. So we need to ignore them - partitions.contains(p) - }.map(p => p -> consumer.position(p)).toMap - logDebug(s"Got earliest offsets for new partitions: $partitionOffsets") - partitionOffsets - } + partitionsAssignedToConsumer(partitions => { + // Get the earliest offset of each partition + consumer.seekToBeginning(partitions) + val partitionOffsets = newPartitions.filter { p => + // When deleting topics happen at the same time, some partitions may not be in + // `partitions`. 
So we need to ignore them + partitions.contains(p) + }.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got earliest offsets for new partitions: $partitionOffsets") + partitionOffsets + }, fetchingEarliestOffset = true) + } + } + + private def partitionsAssignedToConsumer( + body: ju.Set[TopicPartition] => Map[TopicPartition, Long], + fetchingEarliestOffset: Boolean = false) + : Map[TopicPartition, Long] = runUninterruptibly { + + withRetriesWithoutInterrupt { + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + + if (!fetchingEarliestOffset) { + // Call `position` to wait until the potential offset request triggered by `poll(0)` is + // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by + // `poll(0)` may reset offsets that should have been set by another request. + partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {}) } + + consumer.pause(partitions) + logDebug(s"Partitions assigned to consumer: $partitions.") + body(partitions) } } @@ -421,16 +478,3 @@ private[kafka010] class KafkaOffsetReader( _consumer = null // will automatically get reinitialized again } } - -private[kafka010] object KafkaOffsetReader { - - def kafkaSchema: StructType = StructType(Seq( - StructField("key", BinaryType), - StructField("value", BinaryType), - StructField("topic", StringType), - StructField("partition", IntegerType), - StructField("offset", LongType), - StructField("timestamp", TimestampType), - StructField("timestampType", IntegerType) - )) -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala new file mode 100644 index 0000000000000..aed099c142bc3 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.sql.Timestamp + +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.consumer.ConsumerRecord + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** A simple class for converting Kafka ConsumerRecord to InternalRow/UnsafeRow */ +private[kafka010] class KafkaRecordToRowConverter { + import KafkaRecordToRowConverter._ + + private val toUnsafeRowWithoutHeaders = UnsafeProjection.create(schemaWithoutHeaders) + private val toUnsafeRowWithHeaders = UnsafeProjection.create(schemaWithHeaders) + + val toInternalRowWithoutHeaders: Record => InternalRow = + (cr: Record) => InternalRow( + cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset, + DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id + ) + + val toInternalRowWithHeaders: Record => InternalRow = + (cr: Record) => InternalRow( + cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset, + DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id, + if (cr.headers.iterator().hasNext) { + new GenericArrayData(cr.headers.iterator().asScala + .map(header => + InternalRow(UTF8String.fromString(header.key()), header.value()) + ).toArray) + } else { + null + } + ) + + def toUnsafeRowWithoutHeadersProjector: Record => UnsafeRow = + (cr: Record) => toUnsafeRowWithoutHeaders(toInternalRowWithoutHeaders(cr)) + + def toUnsafeRowWithHeadersProjector: Record => UnsafeRow = + (cr: Record) => toUnsafeRowWithHeaders(toInternalRowWithHeaders(cr)) + + def toUnsafeRowProjector(includeHeaders: Boolean): Record => UnsafeRow = { + if (includeHeaders) toUnsafeRowWithHeadersProjector else toUnsafeRowWithoutHeadersProjector + } +} + +private[kafka010] object KafkaRecordToRowConverter { + type Record = ConsumerRecord[Array[Byte], Array[Byte]] + + val headersType = ArrayType(StructType(Array( + StructField("key", StringType), + StructField("value", BinaryType)))) + + private val schemaWithoutHeaders = new StructType(Array( + StructField("key", BinaryType), + StructField("value", BinaryType), + StructField("topic", StringType), + StructField("partition", IntegerType), + StructField("offset", LongType), + StructField("timestamp", TimestampType), + StructField("timestampType", IntegerType) + )) + + private val schemaWithHeaders = + new StructType(schemaWithoutHeaders.fields :+ StructField("headers", headersType)) + + def kafkaSchema(includeHeaders: Boolean): StructType = { + if (includeHeaders) schemaWithHeaders else schemaWithoutHeaders + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala deleted file mode 100644 index 306ef10b775a9..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.apache.kafka.clients.consumer.ConsumerRecord - -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.unsafe.types.UTF8String - -/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ -private[kafka010] class KafkaRecordToUnsafeRowConverter { - private val rowWriter = new UnsafeRowWriter(7) - - def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { - rowWriter.reset() - rowWriter.zeroOutNullBytes() - - if (record.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, record.key) - } - if (record.value == null) { - rowWriter.setNullAt(1) - } else { - rowWriter.write(1, record.value) - } - rowWriter.write(2, UTF8String.fromString(record.topic)) - rowWriter.write(3, record.partition) - rowWriter.write(4, record.offset) - rowWriter.write( - 5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) - rowWriter.write(6, record.timestampType.id) - rowWriter.getRow() - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index dc7087821b10c..61479c992039b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -24,10 +24,9 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String private[kafka010] class KafkaRelation( @@ -36,6 +35,7 @@ private[kafka010] class KafkaRelation( sourceOptions: CaseInsensitiveMap[String], specifiedKafkaParams: Map[String, String], failOnDataLoss: Boolean, + includeHeaders: Boolean, startingOffsets: KafkaOffsetRangeLimit, endingOffsets: KafkaOffsetRangeLimit) extends BaseRelation with TableScan with Logging { @@ -49,7 +49,9 @@ private[kafka010] class KafkaRelation( (sqlContext.sparkContext.conf.get(NETWORK_TIMEOUT) * 1000L).toString ).toLong - override def schema: StructType = KafkaOffsetReader.kafkaSchema + private val converter = new KafkaRecordToRowConverter() + + override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders) override def buildScan(): RDD[Row] = { // Each running query should use its own group id. 
Otherwise, the query may be only assigned @@ -66,8 +68,8 @@ private[kafka010] class KafkaRelation( // Leverage the KafkaReader to obtain the relevant partition offsets val (fromPartitionOffsets, untilPartitionOffsets) = { try { - (kafkaOffsetReader.fetchPartitionOffsets(startingOffsets), - kafkaOffsetReader.fetchPartitionOffsets(endingOffsets)) + (kafkaOffsetReader.fetchPartitionOffsets(startingOffsets, isStartingOffsets = true), + kafkaOffsetReader.fetchPartitionOffsets(endingOffsets, isStartingOffsets = false)) } finally { kafkaOffsetReader.close() } @@ -100,18 +102,14 @@ private[kafka010] class KafkaRelation( // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. val executorKafkaParams = KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) + val toInternalRow = if (includeHeaders) { + converter.toInternalRowWithHeaders + } else { + converter.toInternalRowWithoutHeaders + } val rdd = new KafkaSourceRDD( sqlContext.sparkContext, executorKafkaParams, offsetRanges, - pollTimeoutMs, failOnDataLoss).map { cr => - InternalRow( - cr.key, - cr.value, - UTF8String.fromString(cr.topic), - cr.partition, - cr.offset, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)), - cr.timestampType.id) - } + pollTimeoutMs, failOnDataLoss).map(toInternalRow) sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema).rdd } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index d1a35ec53bc94..e1392b6215d3a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -31,12 +31,11 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * A [[Source]] that reads data from Kafka using the following design. 
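Editorial aside — a minimal usage sketch (not part of the patch) of how the new includeHeaders option threaded through the KafkaRelation and KafkaSource changes above is expected to surface on the batch read path; the topic name and broker address below are placeholders:

    // Batch read with Kafka record headers exposed as an extra column.
    // The schema follows KafkaRecordToRowConverter.kafkaSchema(includeHeaders = true):
    // the usual key/value/topic/partition/offset/timestamp/timestampType columns
    // plus headers: array<struct<key: string, value: binary>>.
    val withHeaders = spark.read
      .format("kafka")
      .option("kafka.bootstrap.servers", "broker1:9092")  // placeholder address
      .option("subscribe", "events")                       // placeholder topic
      .option("includeHeaders", "true")                    // defaults to "false"
      .load()
      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")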
@@ -84,13 +83,15 @@ private[kafka010] class KafkaSource( private val sc = sqlContext.sparkContext - private val pollTimeoutMs = sourceOptions.getOrElse( - KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, - (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString - ).toLong + private val pollTimeoutMs = + sourceOptions.getOrElse(CONSUMER_POLL_TIMEOUT, (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString) + .toLong private val maxOffsetsPerTrigger = - sourceOptions.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER).map(_.toLong) + sourceOptions.get(MAX_OFFSET_PER_TRIGGER).map(_.toLong) + + private val includeHeaders = + sourceOptions.getOrElse(INCLUDE_HEADERS, "false").toBoolean /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only @@ -104,6 +105,8 @@ private[kafka010] class KafkaSource( case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets(None)) case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) + case SpecificTimestampRangeLimit(p) => + kafkaReader.fetchSpecificTimestampBasedOffsets(p, failsOnNoMatchingOffset = true) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -113,7 +116,9 @@ private[kafka010] class KafkaSource( private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None - override def schema: StructType = KafkaOffsetReader.kafkaSchema + private val converter = new KafkaRecordToRowConverter() + + override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders) /** Returns the maximum available offset for this source. */ override def getOffset: Option[Offset] = { @@ -223,7 +228,7 @@ private[kafka010] class KafkaSource( val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet) if (deletedPartitions.nonEmpty) { val message = if (kafkaReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) { - s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}" + s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}" } else { s"$deletedPartitions are gone. Some data may have been missed." } @@ -267,16 +272,14 @@ private[kafka010] class KafkaSource( }.toArray // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. 
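    // Editorial sketch (not part of the patch): what the converter functions used just below
    // produce per ConsumerRecord. For a record cr carrying one header ("a", "b".getBytes),
    // converter.toInternalRowWithHeaders(cr) yields roughly
    //   InternalRow(cr.key, cr.value, UTF8String(cr.topic), cr.partition, cr.offset,
    //               timestampMicros, cr.timestampType.id,
    //               GenericArrayData(Array(InternalRow(UTF8String("a"), "b".getBytes))))
    // whereas toInternalRowWithoutHeaders omits the trailing headers array; a record with no
    // headers maps the headers column to null.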
- val rdd = new KafkaSourceRDD( - sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr => - InternalRow( - cr.key, - cr.value, - UTF8String.fromString(cr.topic), - cr.partition, - cr.offset, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)), - cr.timestampType.id) + val rdd = if (includeHeaders) { + new KafkaSourceRDD( + sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss) + .map(converter.toInternalRowWithHeaders) + } else { + new KafkaSourceRDD( + sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss) + .map(converter.toInternalRowWithoutHeaders) } logInfo("GetBatch generating RDD of offset range: " + diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index 90d70439c5329..b9674a30aee39 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition +import org.apache.spark.sql.connector.read.streaming.PartitionOffset import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index c3f0be4be96e2..c15f08d78741d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,14 +30,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.kafka010.KafkaConfigUpdater import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.read.{Batch, Scan, ScanBuilder} +import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} +import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder} +import org.apache.spark.sql.connector.write.streaming.StreamingWrite import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.TableCapability._ -import org.apache.spark.sql.sources.v2.reader.{Batch, Scan, ScanBuilder} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.{BatchWrite, WriteBuilder} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -70,7 +69,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister val caseInsensitiveParameters = CaseInsensitiveMap(parameters) validateStreamOptions(caseInsensitiveParameters) 
require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one") - (shortName(), KafkaOffsetReader.kafkaSchema) + val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean + (shortName(), KafkaRecordToRowConverter.kafkaSchema(includeHeaders)) } override def createSource( @@ -89,7 +89,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters) val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( - caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + caseInsensitiveParameters, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) val kafkaOffsetReader = new KafkaOffsetReader( strategy(caseInsensitiveParameters), @@ -108,7 +109,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } override def getTable(options: CaseInsensitiveStringMap): KafkaTable = { - new KafkaTable + val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) + new KafkaTable(includeHeaders) } /** @@ -125,19 +127,24 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters) val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( - caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) + caseInsensitiveParameters, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, + STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) assert(startingRelationOffsets != LatestOffsetRangeLimit) val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( - caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + caseInsensitiveParameters, ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, + ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) assert(endingRelationOffsets != EarliestOffsetRangeLimit) + val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean + new KafkaRelation( sqlContext, strategy(caseInsensitiveParameters), sourceOptions = caseInsensitiveParameters, specifiedKafkaParams = specifiedKafkaParams, failOnDataLoss = failOnDataLoss(caseInsensitiveParameters), + includeHeaders = includeHeaders, startingOffsets = startingRelationOffsets, endingOffsets = endingRelationOffsets) } @@ -317,13 +324,17 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Stream specific options params.get(ENDING_OFFSETS_OPTION_KEY).map(_ => throw new IllegalArgumentException("ending offset not valid in streaming queries")) + params.get(ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY).map(_ => + throw new IllegalArgumentException("ending timestamp not valid in streaming queries")) + validateGeneralOptions(params) } private def validateBatchOptions(params: CaseInsensitiveMap[String]) = { // Batch specific options KafkaSourceProvider.getKafkaOffsetRangeLimit( - params, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match { + params, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, STARTING_OFFSETS_OPTION_KEY, + EarliestOffsetRangeLimit) match { case EarliestOffsetRangeLimit => // good to go case LatestOffsetRangeLimit => throw new IllegalArgumentException("starting offset can't be latest " + @@ -335,10 +346,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister "be latest for batch queries on Kafka") case _ => // ignore } + case _: 
SpecificTimestampRangeLimit => // good to go } KafkaSourceProvider.getKafkaOffsetRangeLimit( - params, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match { + params, ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, ENDING_OFFSETS_OPTION_KEY, + LatestOffsetRangeLimit) match { case EarliestOffsetRangeLimit => throw new IllegalArgumentException("ending offset can't be earliest " + "for batch queries on Kafka") @@ -350,6 +363,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister "earliest for batch queries on Kafka") case _ => // ignore } + case _: SpecificTimestampRangeLimit => // good to go } validateGeneralOptions(params) @@ -360,13 +374,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - class KafkaTable extends Table with SupportsRead with SupportsWrite { + class KafkaTable(includeHeaders: Boolean) extends Table with SupportsRead with SupportsWrite { override def name(): String = "KafkaTable" - override def schema(): StructType = KafkaOffsetReader.kafkaSchema + override def schema(): StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders) override def capabilities(): ju.Set[TableCapability] = { + import TableCapability._ // ACCEPT_ANY_SCHEMA is needed because of the following reasons: // * Kafka writer validates the schema instead of the SQL analyzer (the schema is fixed) // * Read schema differs from write schema (please see Kafka integration guide) @@ -403,8 +418,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } class KafkaScan(options: CaseInsensitiveStringMap) extends Scan { + val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) - override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema + override def readSchema(): StructType = { + KafkaRecordToRowConverter.kafkaSchema(includeHeaders) + } override def toBatch(): Batch = { val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap) @@ -412,10 +430,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions) val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( - caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) + caseInsensitiveOptions, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, + STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( - caseInsensitiveOptions, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + caseInsensitiveOptions, ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, + ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) new KafkaBatch( strategy(caseInsensitiveOptions), @@ -423,7 +443,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister specifiedKafkaParams, failOnDataLoss(caseInsensitiveOptions), startingRelationOffsets, - endingRelationOffsets) + endingRelationOffsets, + includeHeaders) } override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { @@ -437,7 +458,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions) val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( - caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + caseInsensitiveOptions, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) val kafkaOffsetReader = new KafkaOffsetReader( 
strategy(caseInsensitiveOptions), @@ -465,7 +487,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions) val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( - caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + caseInsensitiveOptions, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) val kafkaOffsetReader = new KafkaOffsetReader( strategy(caseInsensitiveOptions), @@ -491,6 +514,8 @@ private[kafka010] object KafkaSourceProvider extends Logging { private val STRATEGY_OPTION_KEYS = Set(SUBSCRIBE, SUBSCRIBE_PATTERN, ASSIGN) private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" + private[kafka010] val STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY = "startingoffsetsbytimestamp" + private[kafka010] val ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY = "endingoffsetsbytimestamp" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" private[kafka010] val MIN_PARTITIONS_OPTION_KEY = "minpartitions" private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxoffsetspertrigger" @@ -498,6 +523,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchoffset.retryintervalms" private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaconsumer.polltimeoutms" private val GROUP_ID_PREFIX = "groupidprefix" + private[kafka010] val INCLUDE_HEADERS = "includeheaders" val TOPIC_OPTION_KEY = "topic" @@ -533,15 +559,20 @@ private[kafka010] object KafkaSourceProvider extends Logging { def getKafkaOffsetRangeLimit( params: CaseInsensitiveMap[String], + offsetByTimestampOptionKey: String, offsetOptionKey: String, defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = { - params.get(offsetOptionKey).map(_.trim) match { - case Some(offset) if offset.toLowerCase(Locale.ROOT) == "latest" => - LatestOffsetRangeLimit - case Some(offset) if offset.toLowerCase(Locale.ROOT) == "earliest" => - EarliestOffsetRangeLimit - case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) - case None => defaultOffsets + params.get(offsetByTimestampOptionKey).map(_.trim) match { + case Some(json) => SpecificTimestampRangeLimit(JsonUtils.partitionTimestamps(json)) + case None => + params.get(offsetOptionKey).map(_.trim) match { + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "latest" => + LatestOffsetRangeLimit + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "earliest" => + EarliestOffsetRangeLimit + case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) + case None => defaultOffsets + } } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala index 6dd1d2984a96e..2b50b771e694e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import 
org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 041fac7717635..b423ddc959c1b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -19,9 +19,13 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import scala.collection.JavaConverters._ + import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata} +import org.apache.kafka.common.header.Header +import org.apache.kafka.common.header.internals.RecordHeader -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} import org.apache.spark.sql.types.{BinaryType, StringType} @@ -88,7 +92,17 @@ private[kafka010] abstract class KafkaRowWriter( throw new NullPointerException(s"null topic present in the data. Use the " + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val record = if (projectedRow.isNullAt(3)) { + new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value) + } else { + val headerArray = projectedRow.getArray(3) + val headers = (0 until headerArray.numElements()).map { i => + val struct = headerArray.getStruct(i, 2) + new RecordHeader(struct.getUTF8String(0).toString, struct.getBinary(1)) + .asInstanceOf[Header] + } + new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value, headers.asJava) + } producer.send(record, callback) } @@ -131,9 +145,26 @@ private[kafka010] abstract class KafkaRowWriter( throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + s"attribute unsupported type ${t.catalogString}") } + val headersExpression = inputSchema + .find(_.name == KafkaWriter.HEADERS_ATTRIBUTE_NAME).getOrElse( + Literal(CatalystTypeConverters.convertToCatalyst(null), + KafkaRecordToRowConverter.headersType) + ) + headersExpression.dataType match { + case KafkaRecordToRowConverter.headersType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " + + s"attribute unsupported type ${t.catalogString}") + } UnsafeProjection.create( - Seq(topicExpression, Cast(keyExpression, BinaryType), - Cast(valueExpression, BinaryType)), inputSchema) + Seq( + topicExpression, + Cast(keyExpression, BinaryType), + Cast(valueExpression, BinaryType), + headersExpression + ), + inputSchema + ) } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index e1a9191cc5a84..bbb060356f730 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -21,9 +21,10 @@ import java.{util => ju} import 
org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} -import org.apache.spark.sql.types.{BinaryType, StringType} +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.types.{BinaryType, MapType, StringType} import org.apache.spark.util.Utils /** @@ -39,6 +40,7 @@ private[kafka010] object KafkaWriter extends Logging { val TOPIC_ATTRIBUTE_NAME: String = "topic" val KEY_ATTRIBUTE_NAME: String = "key" val VALUE_ATTRIBUTE_NAME: String = "value" + val HEADERS_ATTRIBUTE_NAME: String = "headers" override def toString: String = "KafkaWriter" @@ -75,6 +77,15 @@ private[kafka010] object KafkaWriter extends Logging { throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}") } + schema.find(_.name == HEADERS_ATTRIBUTE_NAME).getOrElse( + Literal(CatalystTypeConverters.convertToCatalyst(null), + KafkaRecordToRowConverter.headersType) + ).dataType match { + case KafkaRecordToRowConverter.headersType => // good + case _ => + throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " + + s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}") + } } def write( diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/commits/0 b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/metadata b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/metadata new file mode 100644 index 0000000000000..f1b5ab7aa17f0 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/metadata @@ -0,0 +1 @@ +{"id":"fc415a71-f0a2-4c3c-aeaf-f9e258c3f726"} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/offsets/0 b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/offsets/0 new file mode 100644 index 0000000000000..5dbadea57acbe --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1568508285207,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}} +{"spark-test-topic-2b8619f5-d3c4-4c2d-b5d1-8d9d9458aa62":{"2":3,"4":3,"1":3,"3":3,"0":3}} \ No newline at end 
of file diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/sources/0/0 b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/sources/0/0 new file mode 100644 index 0000000000000..8cf9f8e009ce8 Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/sources/0/0 differ diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/0/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/0/1.delta new file mode 100644 index 0000000000000..5815bbdcc2467 Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/0/1.delta differ diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/1/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/1/1.delta new file mode 100644 index 0000000000000..e1a065b2b1c78 Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/1/1.delta differ diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/2/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/2/1.delta new file mode 100644 index 0000000000000..cce14294e0044 Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/2/1.delta differ diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/3/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/3/1.delta new file mode 100644 index 0000000000000..57063019503bc Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/3/1.delta differ diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/4/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/4/1.delta new file mode 100644 index 0000000000000..e8b1e4bdc8dba Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/4/1.delta differ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala index 80f9a1b410d2c..122fe752615ad 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala +++ 
b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.kafka010 +import java.{util => ju} +import java.nio.charset.StandardCharsets import java.util.concurrent.{Executors, TimeUnit} import scala.collection.JavaConverters._ @@ -29,10 +31,14 @@ import org.apache.kafka.common.serialization.ByteArrayDeserializer import org.scalatest.PrivateMethodTester import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.kafka010.KafkaDelegationTokenTest import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey import org.apache.spark.sql.test.SharedSparkSession -class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester { +class KafkaDataConsumerSuite + extends SharedSparkSession + with PrivateMethodTester + with KafkaDelegationTokenTest { protected var testUtils: KafkaTestUtils = _ private val topic = "topic" + Random.nextInt() @@ -65,6 +71,8 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester private var consumerPool: InternalKafkaConsumerPool = _ override def beforeEach(): Unit = { + super.beforeEach() + fetchedDataPool = { val fetchedDataPoolMethod = PrivateMethod[FetchedDataPool]('fetchedDataPool) KafkaDataConsumer.invokePrivate(fetchedDataPoolMethod()) @@ -91,53 +99,93 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester test("new KafkaDataConsumer instance in case of Task retry") { try { val kafkaParams = getKafkaParams() - val key = new CacheKey(groupId, topicPartition) + val key = CacheKey(groupId, topicPartition) val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null) TaskContext.setTaskContext(context1) - val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) - - // any method call which requires consumer is necessary - consumer1.getAvailableOffsetRange() - - val consumer1Underlying = consumer1._consumer - assert(consumer1Underlying.isDefined) - - consumer1.release() - - assert(consumerPool.size(key) === 1) - // check whether acquired object is available in pool - val pooledObj = consumerPool.borrowObject(key, kafkaParams) - assert(consumer1Underlying.get.eq(pooledObj)) - consumerPool.returnObject(pooledObj) + val consumer1Underlying = initSingleConsumer(kafkaParams, key) val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null) TaskContext.setTaskContext(context2) - val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams) - - // any method call which requires consumer is necessary - consumer2.getAvailableOffsetRange() + val consumer2Underlying = initSingleConsumer(kafkaParams, key) - val consumer2Underlying = consumer2._consumer - assert(consumer2Underlying.isDefined) // here we expect different consumer as pool will invalidate for task reattempt - assert(consumer2Underlying.get.ne(consumer1Underlying.get)) + assert(consumer2Underlying.ne(consumer1Underlying)) + } finally { + TaskContext.unset() + } + } - consumer2.release() + test("same KafkaDataConsumer instance in case of same token") { + try { + val kafkaParams = getKafkaParams() + val key = new CacheKey(groupId, topicPartition) + + val context = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null) + TaskContext.setTaskContext(context) + setSparkEnv( + Map( + s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers + ) + ) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) + val consumer1Underlying = initSingleConsumer(kafkaParams, key) + val 
consumer2Underlying = initSingleConsumer(kafkaParams, key) + + assert(consumer2Underlying.eq(consumer1Underlying)) + } finally { + TaskContext.unset() + } + } - // The first consumer should be removed from cache, but the consumer after invalidate - // should be cached. - assert(consumerPool.size(key) === 1) - val pooledObj2 = consumerPool.borrowObject(key, kafkaParams) - assert(consumer2Underlying.get.eq(pooledObj2)) - consumerPool.returnObject(pooledObj2) + test("new KafkaDataConsumer instance in case of token renewal") { + try { + val kafkaParams = getKafkaParams() + val key = new CacheKey(groupId, topicPartition) + + val context = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null) + TaskContext.setTaskContext(context) + setSparkEnv( + Map( + s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers + ) + ) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) + val consumer1Underlying = initSingleConsumer(kafkaParams, key) + addTokenToUGI(tokenService1, tokenId2, tokenPassword2) + val consumer2Underlying = initSingleConsumer(kafkaParams, key) + + assert(consumer2Underlying.ne(consumer1Underlying)) } finally { TaskContext.unset() } } + private def initSingleConsumer( + kafkaParams: ju.Map[String, Object], + key: CacheKey): InternalKafkaConsumer = { + val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams) + + // any method call which requires consumer is necessary + consumer.getOrRetrieveConsumer() + + val consumerUnderlying = consumer._consumer + assert(consumerUnderlying.isDefined) + + consumer.release() + + assert(consumerPool.size(key) === 1) + // check whether acquired object is available in pool + val pooledObj = consumerPool.borrowObject(key, kafkaParams) + assert(consumerUnderlying.get.eq(pooledObj)) + consumerPool.returnObject(pooledObj) + + consumerUnderlying.get + } + test("SPARK-23623: concurrent use of KafkaDataConsumer") { - val data: immutable.IndexedSeq[String] = prepareTestTopicHavingTestMessages(topic) + val data: immutable.IndexedSeq[(String, Seq[(String, Array[Byte])])] = + prepareTestTopicHavingTestMessages(topic) val topicPartition = new TopicPartition(topic, 0) val kafkaParams = getKafkaParams() @@ -157,10 +205,22 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester try { val range = consumer.getAvailableOffsetRange() val rcvd = range.earliest until range.latest map { offset => - val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value() - new String(bytes) + val record = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false) + val value = new String(record.value(), StandardCharsets.UTF_8) + val headers = record.headers().toArray.map(header => (header.key(), header.value())).toSeq + (value, headers) + } + data.zip(rcvd).foreach { case (expected, actual) => + // value + assert(expected._1 === actual._1) + // headers + expected._2.zip(actual._2).foreach { case (l, r) => + // header key + assert(l._1 === r._1) + // header value + assert(l._2 === r._2) + } } - assert(rcvd == data) } catch { case e: Throwable => error = e @@ -307,9 +367,12 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester } private def prepareTestTopicHavingTestMessages(topic: String) = { - val data = (1 to 1000).map(_.toString) + val data = (1 to 1000).map(i => (i.toString, Seq[(String, Array[Byte])]())) testUtils.createTopic(topic, 1) - testUtils.sendMessages(topic, data.toArray) + val messages = data.map { case (value, hdrs) => + new RecordBuilder(topic, 
value).headers(hdrs).build() + } + testUtils.sendMessages(messages) data } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenSuite.scala index 9850a91f34f63..306483825ae3b 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenSuite.scala @@ -82,7 +82,6 @@ class KafkaDelegationTokenSuite extends StreamTest with SharedSparkSession with .format("kafka") .option("checkpointLocation", checkpointDir.getCanonicalPath) .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.security.protocol", SASL_PLAINTEXT.name) .option("topic", topic) .start() @@ -99,7 +98,6 @@ class KafkaDelegationTokenSuite extends StreamTest with SharedSparkSession with val streamingDf = spark.readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.security.protocol", SASL_PLAINTEXT.name) .option("startingOffsets", s"earliest") .option("subscribe", topic) .load() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index ae8a6886b2b4d..3ee59e57a6edf 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -28,13 +28,15 @@ import scala.collection.JavaConverters._ import scala.io.Source import scala.util.Random +import org.apache.commons.io.FileUtils import org.apache.kafka.clients.producer.{ProducerRecord, RecordMetadata} import org.apache.kafka.common.TopicPartition import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} +import org.apache.spark.sql.{Dataset, ForeachWriter, Row, SparkSession} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.streaming._ @@ -42,11 +44,11 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.Utils abstract class KafkaSourceTest extends StreamTest with SharedSparkSession with KafkaTest { @@ -677,7 +679,8 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { }) } - private def testGroupId(groupIdKey: String, validateGroupId: (String, Iterable[String]) => Unit) { + private def testGroupId(groupIdKey: String, + validateGroupId: (String, Iterable[String]) => Unit): Unit = { // Tests code 
path KafkaSourceProvider.{sourceSchema(.), createSource(.)} // as well as KafkaOffsetReader.createConsumer(.) val topic = newTopic() @@ -1162,6 +1165,63 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) } } + test("default config of includeHeader doesn't break existing query from Spark 2.4") { + import testImplicits._ + + // This topic name is migrated from Spark 2.4.3 test run + val topic = "spark-test-topic-2b8619f5-d3c4-4c2d-b5d1-8d9d9458aa62" + // create same topic and messages as test run + testUtils.createTopic(topic, partitions = 5, overwrite = true) + testUtils.sendMessages(topic, Array(-20, -21, -22).map(_.toString), Some(0)) + testUtils.sendMessages(topic, Array(-10, -11, -12).map(_.toString), Some(1)) + testUtils.sendMessages(topic, Array(0, 1, 2).map(_.toString), Some(2)) + testUtils.sendMessages(topic, Array(10, 11, 12).map(_.toString), Some(3)) + testUtils.sendMessages(topic, Array(20, 21, 22).map(_.toString), Some(4)) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val headers = Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8))) + (31 to 35).map { num => + new RecordBuilder(topic, num.toString).partition(num - 31).headers(headers).build() + }.foreach { rec => testUtils.sendMessage(rec) } + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", topic) + .option("startingOffsets", "earliest") + .load() + + val query = kafka.dropDuplicates() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + .map(kv => kv._2.toInt + 1) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. 
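    // Editorial note (not part of the patch): this test relies on includeHeaders defaulting to
    // "false", so KafkaRecordToRowConverter.kafkaSchema(false) matches the fixed 7-column schema
    // that the Spark 2.4.3 query behind this checkpoint was planned against, and the restored
    // dropDuplicates state sees an unchanged input schema.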
+ FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + testStream(query)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + /* + Note: The checkpoint was generated using the following input in Spark version 2.4.3 + testUtils.createTopic(topic, partitions = 5, overwrite = true) + + testUtils.sendMessages(topic, Array(-20, -21, -22).map(_.toString), Some(0)) + testUtils.sendMessages(topic, Array(-10, -11, -12).map(_.toString), Some(1)) + testUtils.sendMessages(topic, Array(0, 1, 2).map(_.toString), Some(2)) + testUtils.sendMessages(topic, Array(10, 11, 12).map(_.toString), Some(3)) + testUtils.sendMessages(topic, Array(20, 21, 22).map(_.toString), Some(4)) + */ + makeSureGetOffsetCalled, + CheckNewAnswer(32, 33, 34, 35, 36) + ) + } } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { @@ -1219,6 +1279,16 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { "failOnDataLoss" -> failOnDataLoss.toString) } + test(s"assign from specific timestamps (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromSpecificTimestamps( + topic, + failOnDataLoss = failOnDataLoss, + addPartitions = false, + "assign" -> assignString(topic, 0 to 4), + "failOnDataLoss" -> failOnDataLoss.toString) + } + test(s"subscribing topic by name from latest offsets (failOnDataLoss: $failOnDataLoss)") { val topic = newTopic() testFromLatestOffsets( @@ -1242,6 +1312,12 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { testFromSpecificOffsets(topic, failOnDataLoss = failOnDataLoss, "subscribe" -> topic) } + test(s"subscribing topic by name from specific timestamps (failOnDataLoss: $failOnDataLoss)") { + val topic = newTopic() + testFromSpecificTimestamps(topic, failOnDataLoss = failOnDataLoss, addPartitions = true, + "subscribe" -> topic) + } + test(s"subscribing topic by pattern from latest offsets (failOnDataLoss: $failOnDataLoss)") { val topicPrefix = newTopic() val topic = topicPrefix + "-suffix" @@ -1270,6 +1346,17 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { failOnDataLoss = failOnDataLoss, "subscribePattern" -> s"$topicPrefix-.*") } + + test(s"subscribing topic by pattern from specific timestamps " + + s"(failOnDataLoss: $failOnDataLoss)") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromSpecificTimestamps( + topic, + failOnDataLoss = failOnDataLoss, + addPartitions = true, + "subscribePattern" -> s"$topicPrefix-.*") + } } test("bad source options") { @@ -1289,6 +1376,9 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { // Specifying an ending offset testBadOptions("endingOffsets" -> "latest")("Ending offset not valid in streaming queries") + testBadOptions("subscribe" -> "t", "endingOffsetsByTimestamp" -> "{\"t\": {\"0\": 1000}}")( + "Ending timestamp not valid in streaming queries") + // No strategy specified testBadOptions()("options must be specified", "subscribe", "subscribePattern") @@ -1337,7 +1427,8 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { (STARTING_OFFSETS_OPTION_KEY, """{"topic-A":{"0":23}}""", SpecificOffsetRangeLimit(Map(new TopicPartition("topic-A", 0) -> 23))))) { val offset = getKafkaOffsetRangeLimit( - CaseInsensitiveMap[String](Map(optionKey -> optionValue)), optionKey, answer) + CaseInsensitiveMap[String](Map(optionKey -> optionValue)), "dummy", optionKey, + answer) assert(offset === answer) } @@ -1345,7 +1436,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { (STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit), 
(ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit))) { val offset = getKafkaOffsetRangeLimit( - CaseInsensitiveMap[String](Map.empty), optionKey, answer) + CaseInsensitiveMap[String](Map.empty), "dummy", optionKey, answer) assert(offset === answer) } } @@ -1410,11 +1501,90 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { ) } + private def testFromSpecificTimestamps( + topic: String, + failOnDataLoss: Boolean, + addPartitions: Boolean, + options: (String, String)*): Unit = { + def sendMessages(topic: String, msgs: Seq[String], part: Int, ts: Long): Unit = { + val records = msgs.map { msg => + new RecordBuilder(topic, msg).partition(part).timestamp(ts).build() + } + testUtils.sendMessages(records) + } + + testUtils.createTopic(topic, partitions = 5) + + val firstTimestamp = System.currentTimeMillis() - 5000 + sendMessages(topic, Array(-20).map(_.toString), 0, firstTimestamp) + sendMessages(topic, Array(-10).map(_.toString), 1, firstTimestamp) + sendMessages(topic, Array(0, 1).map(_.toString), 2, firstTimestamp) + sendMessages(topic, Array(10, 11).map(_.toString), 3, firstTimestamp) + sendMessages(topic, Array(20, 21, 22).map(_.toString), 4, firstTimestamp) + + val secondTimestamp = firstTimestamp + 1000 + sendMessages(topic, Array(-21, -22).map(_.toString), 0, secondTimestamp) + sendMessages(topic, Array(-11, -12).map(_.toString), 1, secondTimestamp) + sendMessages(topic, Array(2).map(_.toString), 2, secondTimestamp) + sendMessages(topic, Array(12).map(_.toString), 3, secondTimestamp) + // no data after second timestamp for partition 4 + + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + // we intentionally starts from second timestamp, + // except for partition 4 - it starts from first timestamp + val startPartitionTimestamps: Map[TopicPartition, Long] = Map( + (0 to 3).map(new TopicPartition(topic, _) -> secondTimestamp): _* + ) ++ Map(new TopicPartition(topic, 4) -> firstTimestamp) + val startingTimestamps = JsonUtils.partitionTimestamps(startPartitionTimestamps) + + val reader = spark + .readStream + .format("kafka") + .option("startingOffsetsByTimestamp", startingTimestamps) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("failOnDataLoss", failOnDataLoss.toString) + options.foreach { case (k, v) => reader.option(k, v) } + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + testStream(mapped)( + makeSureGetOffsetCalled, + Execute { q => + val partitions = (0 to 4).map(new TopicPartition(topic, _)) + // wait to reach the last offset in every partition + q.awaitOffset( + 0, KafkaSourceOffset(partitions.map(tp => tp -> 3L).toMap), streamingTimeout.toMillis) + }, + CheckAnswer(-21, -22, -11, -12, 2, 12, 20, 21, 22), + StopStream, + StartStream(), + CheckAnswer(-21, -22, -11, -12, 2, 12, 20, 21, 22), // Should get the data back on recovery + StopStream, + AddKafkaData(Set(topic), 30, 31, 32), // Add data when stream is stopped + StartStream(), + CheckAnswer(-21, -22, -11, -12, 2, 12, 20, 21, 22, 30, 31, 32), // Should get the added data + AssertOnQuery("Add partitions") { query: StreamExecution => + if (addPartitions) setTopicPartitions(topic, 10, query) + true + }, + AddKafkaData(Set(topic), 40, 41, 42, 43, 44)(ensureDataInMultiplePartition = true), + CheckAnswer(-21, -22, -11, -12, 2, 12, 20, 21, 22, 30, 31, 32, 40, 41, 42, 43, 44), + StopStream + ) + } + 
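Editorial aside — a minimal sketch (not part of the patch) of the user-facing form that testFromSpecificTimestamps above exercises; the topic name, broker address, and epoch-millisecond values are placeholders. The JSON maps topic -> partition id -> timestamp in milliseconds, mirroring JsonUtils.partitionTimestamps, and each timestamp is resolved via the consumer's offsetsForTimes to the earliest offset whose record timestamp is greater than or equal to it:

    // Streaming read that starts each listed partition of "events" from records
    // with timestamps at or after the given epoch milliseconds.
    val startingByTimestamp = """{"events": {"0": 1568508285000, "1": 1568508285000}}"""
    val stream = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "broker1:9092")
      .option("subscribe", "events")
      .option("startingOffsetsByTimestamp", startingByTimestamp)
      .load()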
test("Kafka column types") { val now = System.currentTimeMillis() val topic = newTopic() testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) + testUtils.sendMessage( + new RecordBuilder(topic, "1") + .headers(Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))).build() + ) val kafka = spark .readStream @@ -1423,6 +1593,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { .option("kafka.metadata.max.age.ms", "1") .option("startingOffsets", s"earliest") .option("subscribe", topic) + .option("includeHeaders", "true") .load() val query = kafka @@ -1445,6 +1616,21 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { // producer. So here we just use a low bound to make sure the internal conversion works. assert(row.getAs[java.sql.Timestamp]("timestamp").getTime >= now, s"Unexpected results: $row") assert(row.getAs[Int]("timestampType") === 0, s"Unexpected results: $row") + + def checkHeader(row: Row, expected: Seq[(String, Array[Byte])]): Unit = { + // array> + val headers = row.getList[Row](row.fieldIndex("headers")).asScala + assert(headers.length === expected.length) + + (0 until expected.length).foreach { idx => + val key = headers(idx).getAs[String]("key") + val value = headers(idx).getAs[Array[Byte]]("value") + assert(key === expected(idx)._1) + assert(value === expected(idx)._2) + } + } + + checkHeader(row, Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))) query.stop() } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index b4e1b78c7db4e..063e2e2bc8b77 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.kafka010 +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Locale import java.util.concurrent.atomic.AtomicInteger +import scala.annotation.tailrec import scala.collection.JavaConverters._ import scala.util.Random @@ -27,11 +29,11 @@ import org.apache.kafka.clients.producer.ProducerRecord import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkConf -import org.apache.spark.sql.QueryTest +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrameReader, QueryTest} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -70,7 +72,8 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession protected def createDF( topic: String, withOptions: Map[String, String] = Map.empty[String, String], - brokerAddress: Option[String] = None) = { + brokerAddress: Option[String] = None, + includeHeaders: Boolean = false) = { val df = spark .read .format("kafka") @@ -80,7 +83,13 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession withOptions.foreach { case (key, value) => df.option(key, value) } - df.load().selectExpr("CAST(value AS STRING)") + if (includeHeaders) { + df.option("includeHeaders", "true") + df.load() + .selectExpr("CAST(value AS STRING)", "headers") + } else { + 
df.load().selectExpr("CAST(value AS STRING)") + } } test("explicit earliest to latest offsets") { @@ -147,6 +156,214 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession checkAnswer(df, (0 to 30).map(_.toString).toDF) } + test("default starting and ending offsets with headers") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessage( + new RecordBuilder(topic, "1").headers(Seq()).partition(0).build() + ) + testUtils.sendMessage( + new RecordBuilder(topic, "2").headers( + Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))).partition(1).build() + ) + testUtils.sendMessage( + new RecordBuilder(topic, "3").headers( + Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8)))).partition(2).build() + ) + + // Implicit offset values, should default to earliest and latest + val df = createDF(topic, includeHeaders = true) + // Test that we default to "earliest" and "latest" + checkAnswer(df, Seq(("1", null), + ("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))), + ("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8))))).toDF) + } + + test("timestamp provided for starting and ending") { + val (topic, timestamps) = prepareTimestampRelatedUnitTest + + // timestamp both presented: starting "first" ending "finalized" + verifyTimestampRelatedQueryResult({ df => + val startPartitionTimestamps: Map[TopicPartition, Long] = Map( + (0 to 2).map(new TopicPartition(topic, _) -> timestamps(1)): _*) + val startingTimestamps = JsonUtils.partitionTimestamps(startPartitionTimestamps) + + val endPartitionTimestamps = Map( + (0 to 2).map(new TopicPartition(topic, _) -> timestamps(2)): _*) + val endingTimestamps = JsonUtils.partitionTimestamps(endPartitionTimestamps) + + df.option("startingOffsetsByTimestamp", startingTimestamps) + .option("endingOffsetsByTimestamp", endingTimestamps) + }, topic, 10 to 19) + } + + test("timestamp provided for starting, offset provided for ending") { + val (topic, timestamps) = prepareTimestampRelatedUnitTest + + // starting only presented as "first", and ending presented as endingOffsets + verifyTimestampRelatedQueryResult({ df => + val startTopicTimestamps = Map( + (0 to 2).map(new TopicPartition(topic, _) -> timestamps.head): _*) + val startingTimestamps = JsonUtils.partitionTimestamps(startTopicTimestamps) + + val endPartitionOffsets = Map( + new TopicPartition(topic, 0) -> -1L, // -1 => latest + new TopicPartition(topic, 1) -> -1L, + new TopicPartition(topic, 2) -> 1L // explicit offset - take only first one + ) + val endingOffsets = JsonUtils.partitionOffsets(endPartitionOffsets) + + // so we here expect full of records from partition 0 and 1, and only the first record + // from partition 2 which is "2" + + df.option("startingOffsetsByTimestamp", startingTimestamps) + .option("endingOffsets", endingOffsets) + }, topic, (0 to 29).filterNot(_ % 3 == 2) ++ Seq(2)) + } + + test("timestamp provided for ending, offset provided for starting") { + val (topic, timestamps) = prepareTimestampRelatedUnitTest + + // ending only presented as "third", and starting presented as startingOffsets + verifyTimestampRelatedQueryResult({ df => + val startPartitionOffsets = Map( + new TopicPartition(topic, 0) -> -2L, // -2 => earliest + new TopicPartition(topic, 1) -> -2L, + new TopicPartition(topic, 2) -> 0L // explicit earliest + ) + val startingOffsets = JsonUtils.partitionOffsets(startPartitionOffsets) + + val endTopicTimestamps = Map( + (0 to 2).map(new TopicPartition(topic, _) -> 
timestamps(2)): _*) + val endingTimestamps = JsonUtils.partitionTimestamps(endTopicTimestamps) + + df.option("startingOffsets", startingOffsets) + .option("endingOffsetsByTimestamp", endingTimestamps) + }, topic, 0 to 19) + } + + test("timestamp provided for starting, ending not provided") { + val (topic, timestamps) = prepareTimestampRelatedUnitTest + + // starting only presented as "second", and ending not presented + verifyTimestampRelatedQueryResult({ df => + val startTopicTimestamps = Map( + (0 to 2).map(new TopicPartition(topic, _) -> timestamps(1)): _*) + val startingTimestamps = JsonUtils.partitionTimestamps(startTopicTimestamps) + + df.option("startingOffsetsByTimestamp", startingTimestamps) + }, topic, 10 to 29) + } + + test("timestamp provided for ending, starting not provided") { + val (topic, timestamps) = prepareTimestampRelatedUnitTest + + // ending only presented as "third", and starting not presented + verifyTimestampRelatedQueryResult({ df => + val endTopicTimestamps = Map( + (0 to 2).map(new TopicPartition(topic, _) -> timestamps(2)): _*) + val endingTimestamps = JsonUtils.partitionTimestamps(endTopicTimestamps) + + df.option("endingOffsetsByTimestamp", endingTimestamps) + }, topic, 0 to 19) + } + + test("no matched offset for timestamp - startingOffsets") { + val (topic, timestamps) = prepareTimestampRelatedUnitTest + + val e = intercept[SparkException] { + verifyTimestampRelatedQueryResult({ df => + // partition 2 will make query fail + val startTopicTimestamps = Map( + (0 to 1).map(new TopicPartition(topic, _) -> timestamps(1)): _*) ++ + Map(new TopicPartition(topic, 2) -> Long.MaxValue) + + val startingTimestamps = JsonUtils.partitionTimestamps(startTopicTimestamps) + + df.option("startingOffsetsByTimestamp", startingTimestamps) + }, topic, Seq.empty) + } + + @tailrec + def assertionErrorInExceptionChain(e: Throwable): Boolean = { + if (e.isInstanceOf[AssertionError]) { + true + } else if (e.getCause == null) { + false + } else { + assertionErrorInExceptionChain(e.getCause) + } + } + + assert(assertionErrorInExceptionChain(e), + "Cannot find expected AssertionError in chained exceptions") + } + + test("no matched offset for timestamp - endingOffsets") { + val (topic, timestamps) = prepareTimestampRelatedUnitTest + + // the query will run fine, since we allow no matching offset for timestamp + // if it's endingOffsets + // for partition 0 and 1, it only takes records between first and second timestamp + // for partition 2, it will take all records + verifyTimestampRelatedQueryResult({ df => + val endTopicTimestamps = Map( + (0 to 1).map(new TopicPartition(topic, _) -> timestamps(1)): _*) ++ + Map(new TopicPartition(topic, 2) -> Long.MaxValue) + + val endingTimestamps = JsonUtils.partitionTimestamps(endTopicTimestamps) + + df.option("endingOffsetsByTimestamp", endingTimestamps) + }, topic, (0 to 9) ++ (10 to 29).filter(_ % 3 == 2)) + } + + private def prepareTimestampRelatedUnitTest: (String, Seq[Long]) = { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + + def sendMessages(topic: String, msgs: Array[String], part: Int, ts: Long): Unit = { + val records = msgs.map { msg => + new RecordBuilder(topic, msg).partition(part).timestamp(ts).build() + } + testUtils.sendMessages(records) + } + + val firstTimestamp = System.currentTimeMillis() - 5000 + (0 to 2).foreach { partNum => + sendMessages(topic, (0 to 9).filter(_ % 3 == partNum) + .map(_.toString).toArray, partNum, firstTimestamp) + } + + val secondTimestamp = firstTimestamp + 1000 + (0 to 
2).foreach { partNum => + sendMessages(topic, (10 to 19).filter(_ % 3 == partNum) + .map(_.toString).toArray, partNum, secondTimestamp) + } + + val thirdTimestamp = secondTimestamp + 1000 + (0 to 2).foreach { partNum => + sendMessages(topic, (20 to 29).filter(_ % 3 == partNum) + .map(_.toString).toArray, partNum, thirdTimestamp) + } + + val finalizedTimestamp = thirdTimestamp + 1000 + + (topic, Seq(firstTimestamp, secondTimestamp, thirdTimestamp, finalizedTimestamp)) + } + + private def verifyTimestampRelatedQueryResult( + optionFn: DataFrameReader => DataFrameReader, + topic: String, + expectation: Seq[Int]): Unit = { + val df = spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + val df2 = optionFn(df).load().selectExpr("CAST(value AS STRING)") + checkAnswer(df2, expectation.map(_.toString).toDF) + } + test("reuse same dataframe in query") { // This test ensures that we do not cache the Kafka Consumer in KafkaRelation val topic = newTopic() @@ -263,7 +480,8 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession }) } - private def testGroupId(groupIdKey: String, validateGroupId: (String, Iterable[String]) => Unit) { + private def testGroupId(groupIdKey: String, + validateGroupId: (String, Iterable[String]) => Unit): Unit = { // Tests code path KafkaSourceProvider.createRelation(.) val topic = newTopic() testUtils.createTopic(topic, partitions = 3) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 84ad41610cccd..d77b9a3b6a9e1 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Locale import java.util.concurrent.atomic.AtomicInteger @@ -32,7 +33,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.sql.types.{BinaryType, DataType, StringType, StructField, StructType} abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest { protected var testUtils: KafkaTestUtils = _ @@ -59,13 +60,14 @@ abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - protected def createKafkaReader(topic: String): DataFrame = { + protected def createKafkaReader(topic: String, includeHeaders: Boolean = false): DataFrame = { spark.read .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("startingOffsets", "earliest") .option("endingOffsets", "latest") .option("subscribe", topic) + .option("includeHeaders", includeHeaders.toString) .load() } } @@ -368,15 +370,52 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { test("batch - write to kafka") { val topic = newTopic() testUtils.createTopic(topic) - val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value") + val data = Seq( + Row(topic, "1", Seq( + Row("a", "b".getBytes(UTF_8)) + )), + Row(topic, "2", Seq( + Row("c", "d".getBytes(UTF_8)), + 
Row("e", "f".getBytes(UTF_8)) + )), + Row(topic, "3", Seq( + Row("g", "h".getBytes(UTF_8)), + Row("g", "i".getBytes(UTF_8)) + )), + Row(topic, "4", null), + Row(topic, "5", Seq( + Row("j", "k".getBytes(UTF_8)), + Row("j", "l".getBytes(UTF_8)), + Row("m", "n".getBytes(UTF_8)) + )) + ) + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(data), + StructType(Seq(StructField("topic", StringType), StructField("value", StringType), + StructField("headers", KafkaRecordToRowConverter.headersType))) + ) + df.write .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("topic", topic) + .mode("append") .save() checkAnswer( - createKafkaReader(topic).selectExpr("CAST(value as STRING) value"), - Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil) + createKafkaReader(topic, includeHeaders = true).selectExpr( + "CAST(value as STRING) value", "headers" + ), + Row("1", Seq(Row("a", "b".getBytes(UTF_8)))) :: + Row("2", Seq(Row("c", "d".getBytes(UTF_8)), Row("e", "f".getBytes(UTF_8)))) :: + Row("3", Seq(Row("g", "h".getBytes(UTF_8)), Row("g", "i".getBytes(UTF_8)))) :: + Row("4", null) :: + Row("5", Seq( + Row("j", "k".getBytes(UTF_8)), + Row("j", "l".getBytes(UTF_8)), + Row("m", "n".getBytes(UTF_8)))) :: + Nil + ) } test("batch - null topic field value, and no topic option") { @@ -385,12 +424,13 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { df.write .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .mode("append") .save() } TestUtils.assertExceptionMsg(ex, "null topic present in the data") } - protected def testUnsupportedSaveModes(msg: (SaveMode) => String) { + protected def testUnsupportedSaveModes(msg: (SaveMode) => String): Unit = { val topic = newTopic() testUtils.createTopic(topic) val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") @@ -419,6 +459,7 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("topic", topic) + .mode("append") .save() } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala index 8e6de88865e06..f7b00b31ebba0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.mockito.Mockito.{mock, when} import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} -import org.apache.spark.sql.sources.v2.reader.Scan +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.util.CaseInsensitiveStringMap class KafkaSourceProviderSuite extends SparkFunSuite { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index d7cb30f530396..bbb72bf9973e3 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.kafka010 import java.io.{File, IOException} import java.lang.{Integer => JInt} -import java.net.InetSocketAddress 
+import java.net.{InetAddress, InetSocketAddress} import java.nio.charset.StandardCharsets import java.util.{Collections, Map => JMap, Properties, UUID} import java.util.concurrent.TimeUnit @@ -41,6 +41,8 @@ import org.apache.kafka.clients.consumer.KafkaConsumer import org.apache.kafka.clients.producer._ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.header.Header +import org.apache.kafka.common.header.internals.RecordHeader import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT} import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} @@ -66,10 +68,13 @@ class KafkaTestUtils( private val JAVA_AUTH_CONFIG = "java.security.auth.login.config" + private val localCanonicalHostName = InetAddress.getLoopbackAddress().getCanonicalHostName() + logInfo(s"Local host name is $localCanonicalHostName") + private var kdc: MiniKdc = _ // Zookeeper related configurations - private val zkHost = "localhost" + private val zkHost = localCanonicalHostName private var zkPort: Int = 0 private val zkConnectionTimeout = 60000 private val zkSessionTimeout = 10000 @@ -78,12 +83,12 @@ class KafkaTestUtils( private var zkUtils: ZkUtils = _ // Kafka broker related configurations - private val brokerHost = "localhost" + private val brokerHost = localCanonicalHostName private var brokerPort = 0 private var brokerConf: KafkaConfig = _ private val brokerServiceName = "kafka" - private val clientUser = "client/localhost" + private val clientUser = s"client/$localCanonicalHostName" private var clientKeytabFile: File = _ // Kafka broker server @@ -137,17 +142,17 @@ class KafkaTestUtils( assert(kdcReady, "KDC should be set up beforehand") val baseDir = Utils.createTempDir() - val zkServerUser = "zookeeper/localhost" + val zkServerUser = s"zookeeper/$localCanonicalHostName" val zkServerKeytabFile = new File(baseDir, "zookeeper.keytab") kdc.createPrincipal(zkServerKeytabFile, zkServerUser) logDebug(s"Created keytab file: ${zkServerKeytabFile.getAbsolutePath()}") - val zkClientUser = "zkclient/localhost" + val zkClientUser = s"zkclient/$localCanonicalHostName" val zkClientKeytabFile = new File(baseDir, "zkclient.keytab") kdc.createPrincipal(zkClientKeytabFile, zkClientUser) logDebug(s"Created keytab file: ${zkClientKeytabFile.getAbsolutePath()}") - val kafkaServerUser = "kafka/localhost" + val kafkaServerUser = s"kafka/$localCanonicalHostName" val kafkaServerKeytabFile = new File(baseDir, "kafka.keytab") kdc.createPrincipal(kafkaServerKeytabFile, kafkaServerUser) logDebug(s"Created keytab file: ${kafkaServerKeytabFile.getAbsolutePath()}") @@ -348,38 +353,33 @@ class KafkaTestUtils( } } - /** Java-friendly function for sending messages to the Kafka broker */ - def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { - sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) + def sendMessages(topic: String, msgs: Array[String]): Seq[(String, RecordMetadata)] = { + sendMessages(topic, msgs, None) } - /** Send the messages to the Kafka broker */ - def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { - val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray - sendMessages(topic, messages) + def sendMessages( + topic: String, + msgs: Array[String], + part: Option[Int]): Seq[(String, RecordMetadata)] = { + val records = msgs.map { msg => + val builder 
= new RecordBuilder(topic, msg) + part.foreach { p => builder.partition(p) } + builder.build() + } + sendMessages(records) } - /** Send the array of messages to the Kafka broker */ - def sendMessages(topic: String, messages: Array[String]): Seq[(String, RecordMetadata)] = { - sendMessages(topic, messages, None) + def sendMessage(msg: ProducerRecord[String, String]): Seq[(String, RecordMetadata)] = { + sendMessages(Array(msg)) } - /** Send the array of messages to the Kafka broker using specified partition */ - def sendMessages( - topic: String, - messages: Array[String], - partition: Option[Int]): Seq[(String, RecordMetadata)] = { + def sendMessages(msgs: Seq[ProducerRecord[String, String]]): Seq[(String, RecordMetadata)] = { producer = new KafkaProducer[String, String](producerConfiguration) val offsets = try { - messages.map { m => - val record = partition match { - case Some(p) => new ProducerRecord[String, String](topic, p, null, m) - case None => new ProducerRecord[String, String](topic, m) - } - val metadata = - producer.send(record).get(10, TimeUnit.SECONDS) - logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}") - (m, metadata) + msgs.map { msg => + val metadata = producer.send(msg).get(10, TimeUnit.SECONDS) + logInfo(s"\tSent ($msg) to partition ${metadata.partition}, offset ${metadata.offset}") + (msg.value(), metadata) } } finally { if (producer != null) { @@ -550,7 +550,7 @@ class KafkaTestUtils( zkUtils: ZkUtils, topic: String, numPartitions: Int, - servers: Seq[KafkaServer]) { + servers: Seq[KafkaServer]): Unit = { eventually(timeout(1.minute), interval(200.milliseconds)) { try { verifyTopicDeletion(topic, numPartitions, servers) @@ -613,7 +613,7 @@ class KafkaTestUtils( val actualPort = factory.getLocalPort - def shutdown() { + def shutdown(): Unit = { factory.shutdown() // The directories are not closed even if the ZooKeeper server is shut down. // Please see ZOOKEEPER-1844, which is fixed in 3.4.6+. It leads to test failures @@ -634,4 +634,3 @@ class KafkaTestUtils( } } } - diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/RecordBuilder.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/RecordBuilder.scala new file mode 100644 index 0000000000000..ef07798442e56 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/RecordBuilder.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.lang.{Integer => JInt, Long => JLong} + +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.header.Header +import org.apache.kafka.common.header.internals.RecordHeader + +class RecordBuilder(topic: String, value: String) { + var _partition: Option[JInt] = None + var _timestamp: Option[JLong] = None + var _key: Option[String] = None + var _headers: Option[Seq[(String, Array[Byte])]] = None + + def partition(part: JInt): RecordBuilder = { + _partition = Some(part) + this + } + + def partition(part: Int): RecordBuilder = { + _partition = Some(part.intValue()) + this + } + + def timestamp(ts: JLong): RecordBuilder = { + _timestamp = Some(ts) + this + } + + def timestamp(ts: Long): RecordBuilder = { + _timestamp = Some(ts.longValue()) + this + } + + def key(k: String): RecordBuilder = { + _key = Some(k) + this + } + + def headers(hdrs: Seq[(String, Array[Byte])]): RecordBuilder = { + _headers = Some(hdrs) + this + } + + def build(): ProducerRecord[String, String] = { + val part = _partition.orNull + val ts = _timestamp.orNull + val k = _key.orNull + val hdrs = _headers.map { h => + h.map { case (k, v) => new RecordHeader(k, v).asInstanceOf[Header] } + }.map(_.asJava).orNull + + new ProducerRecord[String, String](topic, part, ts, k, value, hdrs) + } +} diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaConfigUpdater.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaConfigUpdater.scala index 0c61045d6d487..f54ff0d146f7a 100644 --- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaConfigUpdater.scala +++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaConfigUpdater.scala @@ -57,6 +57,12 @@ private[spark] case class KafkaConfigUpdater(module: String, kafkaParams: Map[St } def setAuthenticationConfigIfNeeded(): this.type = { + val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(SparkEnv.get.conf, + kafkaParams(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG).asInstanceOf[String]) + setAuthenticationConfigIfNeeded(clusterConfig) + } + + def setAuthenticationConfigIfNeeded(clusterConfig: Option[KafkaTokenClusterConf]): this.type = { // There are multiple possibilities to log in and applied in the following order: // - JVM global security provided -> try to log in with JVM global security configuration // which can be configured for example with 'java.security.auth.login.config'. 
@@ -66,10 +72,9 @@ private[spark] case class KafkaConfigUpdater(module: String, kafkaParams: Map[St if (KafkaTokenUtil.isGlobalJaasConfigurationProvided) { logDebug("JVM global security configuration detected, using it for login.") } else { - val clusterConfig = KafkaTokenUtil.findMatchingToken(SparkEnv.get.conf, - map.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG).asInstanceOf[String]) clusterConfig.foreach { clusterConf => logDebug("Delegation token detected, using it for login.") + setIfUnset(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, clusterConf.securityProtocol) val jaasParams = KafkaTokenUtil.getTokenJaasParams(clusterConf) set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) require(clusterConf.tokenMechanism.startsWith("SCRAM"), diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala index e1f3c800a51f8..ed4a6f1e34c55 100644 --- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala +++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala @@ -57,6 +57,7 @@ private [kafka010] object KafkaTokenSparkConf extends Logging { val CLUSTERS_CONFIG_PREFIX = "spark.kafka.clusters." val DEFAULT_TARGET_SERVERS_REGEX = ".*" val DEFAULT_SASL_KERBEROS_SERVICE_NAME = "kafka" + val DEFAULT_SECURITY_PROTOCOL_CONFIG = SASL_SSL.name val DEFAULT_SASL_TOKEN_MECHANISM = "SCRAM-SHA-512" def getClusterConfig(sparkConf: SparkConf, identifier: String): KafkaTokenClusterConf = { @@ -72,7 +73,8 @@ private [kafka010] object KafkaTokenSparkConf extends Logging { s"${configPrefix}auth.${CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG}")), sparkClusterConf.getOrElse(s"target.${CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG}.regex", KafkaTokenSparkConf.DEFAULT_TARGET_SERVERS_REGEX), - sparkClusterConf.getOrElse(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, SASL_SSL.name), + sparkClusterConf.getOrElse(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + DEFAULT_SECURITY_PROTOCOL_CONFIG), sparkClusterConf.getOrElse(SaslConfigs.SASL_KERBEROS_SERVICE_NAME, KafkaTokenSparkConf.DEFAULT_SASL_KERBEROS_SERVICE_NAME), sparkClusterConf.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG), diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala index 39e3ac74a9aeb..0ebe98330b4ae 100644 --- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala +++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala @@ -36,7 +36,7 @@ import org.apache.kafka.common.security.auth.SecurityProtocol.{SASL_PLAINTEXT, S import org.apache.kafka.common.security.scram.ScramLoginModule import org.apache.kafka.common.security.token.delegation.DelegationToken -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -241,8 +241,8 @@ private[spark] object KafkaTokenUtil extends Logging { "TOKENID", "HMAC", "OWNER", "RENEWERS", "ISSUEDATE", "EXPIRYDATE", "MAXDATE")) val tokenInfo = token.tokenInfo logDebug("%-15s %-15s %-15s %-25s %-15s %-15s %-15s".format( - REDACTION_REPLACEMENT_TEXT, tokenInfo.tokenId, + 
REDACTION_REPLACEMENT_TEXT, tokenInfo.owner, tokenInfo.renewersAsString, dateFormat.format(tokenInfo.issueTimestamp), @@ -251,7 +251,7 @@ private[spark] object KafkaTokenUtil extends Logging { } } - def findMatchingToken( + def findMatchingTokenClusterConfig( sparkConf: SparkConf, bootStrapServers: String): Option[KafkaTokenClusterConf] = { val tokens = UserGroupInformation.getCurrentUser().getCredentials.getAllTokens.asScala @@ -272,6 +272,7 @@ private[spark] object KafkaTokenUtil extends Logging { def getTokenJaasParams(clusterConf: KafkaTokenClusterConf): String = { val token = UserGroupInformation.getCurrentUser().getCredentials.getToken( getTokenService(clusterConf.identifier)) + require(token != null, s"Token for identifier ${clusterConf.identifier} must exist") val username = new String(token.getIdentifier) val password = new String(token.getPassword) @@ -288,4 +289,17 @@ private[spark] object KafkaTokenUtil extends Logging { params } + + def isConnectorUsingCurrentToken( + params: ju.Map[String, Object], + clusterConfig: Option[KafkaTokenClusterConf]): Boolean = { + if (params.containsKey(SaslConfigs.SASL_JAAS_CONFIG)) { + logDebug("Delegation token used by connector, checking if uses the latest token.") + val consumerJaasParams = params.get(SaslConfigs.SASL_JAAS_CONFIG).asInstanceOf[String] + require(clusterConfig.isDefined, "Delegation token must exist for this connector.") + getTokenJaasParams(clusterConfig.get) == consumerJaasParams + } else { + true + } + } } diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaConfigUpdaterSuite.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaConfigUpdaterSuite.scala index 7a172892e778c..dc1e7cb8d979e 100644 --- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaConfigUpdaterSuite.scala +++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaConfigUpdaterSuite.scala @@ -17,8 +17,13 @@ package org.apache.spark.kafka010 +import java.{util => ju} + +import scala.collection.JavaConverters._ + import org.apache.kafka.clients.CommonClientConfigs import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.security.auth.SecurityProtocol.SASL_PLAINTEXT import org.apache.spark.SparkFunSuite @@ -62,36 +67,64 @@ class KafkaConfigUpdaterSuite extends SparkFunSuite with KafkaDelegationTokenTes } test("setAuthenticationConfigIfNeeded with global security should not set values") { - val params = Map.empty[String, String] + val params = Map( + CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG -> bootStrapServers + ) + setSparkEnv( + Map( + s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers + ) + ) setGlobalKafkaClientConfig() val updatedParams = KafkaConfigUpdater(testModule, params) .setAuthenticationConfigIfNeeded() .build() - assert(updatedParams.size() === 0) + assert(updatedParams.asScala === params) } test("setAuthenticationConfigIfNeeded with token should set values") { val params = Map( CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG -> bootStrapServers ) + testWithTokenSetValues(params) { updatedParams => + assert(updatedParams.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) === + KafkaTokenSparkConf.DEFAULT_SECURITY_PROTOCOL_CONFIG) + } + } + + test("setAuthenticationConfigIfNeeded with token should not override user-defined protocol") { + val overrideProtocolName = SASL_PLAINTEXT.name + val params = Map( + CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG -> 
bootStrapServers, + CommonClientConfigs.SECURITY_PROTOCOL_CONFIG -> overrideProtocolName + ) + testWithTokenSetValues(params) { updatedParams => + assert(updatedParams.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) === + overrideProtocolName) + } + } + + def testWithTokenSetValues(params: Map[String, String]) + (validate: (ju.Map[String, Object]) => Unit): Unit = { setSparkEnv( Map( s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers ) ) - addTokenToUGI(tokenService1) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) val updatedParams = KafkaConfigUpdater(testModule, params) .setAuthenticationConfigIfNeeded() .build() - assert(updatedParams.size() === 3) + assert(updatedParams.size() === 4) assert(updatedParams.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) === bootStrapServers) assert(updatedParams.containsKey(SaslConfigs.SASL_JAAS_CONFIG)) assert(updatedParams.get(SaslConfigs.SASL_MECHANISM) === KafkaTokenSparkConf.DEFAULT_SASL_TOKEN_MECHANISM) + validate(updatedParams) } test("setAuthenticationConfigIfNeeded with invalid mechanism should throw exception") { @@ -104,7 +137,7 @@ class KafkaConfigUpdaterSuite extends SparkFunSuite with KafkaDelegationTokenTes s"spark.kafka.clusters.$identifier1.sasl.token.mechanism" -> "intentionally_invalid" ) ) - addTokenToUGI(tokenService1) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) val e = intercept[IllegalArgumentException] { KafkaConfigUpdater(testModule, params) diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala index eebbf96afa470..19335f4221e40 100644 --- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala +++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala @@ -37,8 +37,12 @@ trait KafkaDelegationTokenTest extends BeforeAndAfterEach { private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) - protected val tokenId = "tokenId" + ju.UUID.randomUUID().toString - protected val tokenPassword = "tokenPassword" + ju.UUID.randomUUID().toString + private var savedSparkEnv: SparkEnv = _ + + protected val tokenId1 = "tokenId" + ju.UUID.randomUUID().toString + protected val tokenPassword1 = "tokenPassword" + ju.UUID.randomUUID().toString + protected val tokenId2 = "tokenId" + ju.UUID.randomUUID().toString + protected val tokenPassword2 = "tokenPassword" + ju.UUID.randomUUID().toString protected val identifier1 = "cluster1" protected val identifier2 = "cluster2" @@ -72,11 +76,16 @@ trait KafkaDelegationTokenTest extends BeforeAndAfterEach { } } + override def beforeEach(): Unit = { + super.beforeEach() + savedSparkEnv = SparkEnv.get + } + override def afterEach(): Unit = { try { Configuration.setConfiguration(null) - UserGroupInformation.setLoginUser(null) - SparkEnv.set(null) + UserGroupInformation.reset() + SparkEnv.set(savedSparkEnv) } finally { super.afterEach() } @@ -86,7 +95,7 @@ trait KafkaDelegationTokenTest extends BeforeAndAfterEach { Configuration.setConfiguration(new KafkaJaasConfiguration) } - protected def addTokenToUGI(tokenService: Text): Unit = { + protected def addTokenToUGI(tokenService: Text, tokenId: String, tokenPassword: String): Unit = { val token = new Token[KafkaDelegationTokenIdentifier]( tokenId.getBytes, tokenPassword.getBytes, diff --git 
a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaRedactionUtilSuite.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaRedactionUtilSuite.scala index 42a9fb5567b6f..225afbe5f3649 100644 --- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaRedactionUtilSuite.scala +++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaRedactionUtilSuite.scala @@ -68,7 +68,7 @@ class KafkaRedactionUtilSuite extends SparkFunSuite with KafkaDelegationTokenTes test("redactParams should redact token password from parameters") { setSparkEnv(Map.empty) val groupId = "id-" + ju.UUID.randomUUID().toString - addTokenToUGI(tokenService1) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) val clusterConf = createClusterConf(identifier1, SASL_SSL.name) val jaasParams = KafkaTokenUtil.getTokenJaasParams(clusterConf) val kafkaParams = Seq( @@ -81,8 +81,8 @@ class KafkaRedactionUtilSuite extends SparkFunSuite with KafkaDelegationTokenTes assert(redactedParams.size === 2) assert(redactedParams.get(ConsumerConfig.GROUP_ID_CONFIG).get === groupId) val redactedJaasParams = redactedParams.get(SaslConfigs.SASL_JAAS_CONFIG).get - assert(redactedJaasParams.contains(tokenId)) - assert(!redactedJaasParams.contains(tokenPassword)) + assert(redactedJaasParams.contains(tokenId1)) + assert(!redactedJaasParams.contains(tokenPassword1)) } test("redactParams should redact passwords from parameters") { @@ -113,13 +113,13 @@ class KafkaRedactionUtilSuite extends SparkFunSuite with KafkaDelegationTokenTes } test("redactJaasParam should redact token password") { - addTokenToUGI(tokenService1) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) val clusterConf = createClusterConf(identifier1, SASL_SSL.name) val jaasParams = KafkaTokenUtil.getTokenJaasParams(clusterConf) val redactedJaasParams = KafkaRedactionUtil.redactJaasParam(jaasParams) - assert(redactedJaasParams.contains(tokenId)) - assert(!redactedJaasParams.contains(tokenPassword)) + assert(redactedJaasParams.contains(tokenId1)) + assert(!redactedJaasParams.contains(tokenPassword1)) } } diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala index 5496195b41490..6fa1b56bff977 100644 --- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala +++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala @@ -17,15 +17,18 @@ package org.apache.spark.kafka010 +import java.{util => ju} import java.security.PrivilegedExceptionAction +import scala.collection.JavaConverters._ + import org.apache.hadoop.io.Text import org.apache.hadoop.security.UserGroupInformation import org.apache.kafka.clients.CommonClientConfigs import org.apache.kafka.common.config.{SaslConfigs, SslConfigs} import org.apache.kafka.common.security.auth.SecurityProtocol.{SASL_PLAINTEXT, SASL_SSL, SSL} -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} import org.apache.spark.internal.config._ class KafkaTokenUtilSuite extends SparkFunSuite with KafkaDelegationTokenTest { @@ -174,58 +177,102 @@ class KafkaTokenUtilSuite extends SparkFunSuite with KafkaDelegationTokenTest { assert(KafkaTokenUtil.isGlobalJaasConfigurationProvided) } - test("findMatchingToken without 
token should return None") { - assert(KafkaTokenUtil.findMatchingToken(sparkConf, bootStrapServers) === None) + test("findMatchingTokenClusterConfig without token should return None") { + assert(KafkaTokenUtil.findMatchingTokenClusterConfig(sparkConf, bootStrapServers) === None) } - test("findMatchingToken with non-matching tokens should return None") { + test("findMatchingTokenClusterConfig with non-matching tokens should return None") { sparkConf.set(s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers", bootStrapServers) sparkConf.set(s"spark.kafka.clusters.$identifier1.target.bootstrap.servers.regex", nonMatchingTargetServersRegex) sparkConf.set(s"spark.kafka.clusters.$identifier2.bootstrap.servers", bootStrapServers) sparkConf.set(s"spark.kafka.clusters.$identifier2.target.bootstrap.servers.regex", matchingTargetServersRegex) - addTokenToUGI(tokenService1) - addTokenToUGI(new Text("intentionally_garbage")) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) + addTokenToUGI(new Text("intentionally_garbage"), tokenId1, tokenPassword1) - assert(KafkaTokenUtil.findMatchingToken(sparkConf, bootStrapServers) === None) + assert(KafkaTokenUtil.findMatchingTokenClusterConfig(sparkConf, bootStrapServers) === None) } - test("findMatchingToken with one matching token should return cluster configuration") { + test("findMatchingTokenClusterConfig with one matching token should return token and cluster " + + "configuration") { sparkConf.set(s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers", bootStrapServers) sparkConf.set(s"spark.kafka.clusters.$identifier1.target.bootstrap.servers.regex", matchingTargetServersRegex) - addTokenToUGI(tokenService1) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) - assert(KafkaTokenUtil.findMatchingToken(sparkConf, bootStrapServers) === - Some(KafkaTokenSparkConf.getClusterConfig(sparkConf, identifier1))) + val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(sparkConf, bootStrapServers) + assert(clusterConfig.get === KafkaTokenSparkConf.getClusterConfig(sparkConf, identifier1)) } - test("findMatchingToken with multiple matching tokens should throw exception") { + test("findMatchingTokenClusterConfig with multiple matching tokens should throw exception") { sparkConf.set(s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers", bootStrapServers) sparkConf.set(s"spark.kafka.clusters.$identifier1.target.bootstrap.servers.regex", matchingTargetServersRegex) sparkConf.set(s"spark.kafka.clusters.$identifier2.auth.bootstrap.servers", bootStrapServers) sparkConf.set(s"spark.kafka.clusters.$identifier2.target.bootstrap.servers.regex", matchingTargetServersRegex) - addTokenToUGI(tokenService1) - addTokenToUGI(tokenService2) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) + addTokenToUGI(tokenService2, tokenId1, tokenPassword1) val thrown = intercept[IllegalArgumentException] { - KafkaTokenUtil.findMatchingToken(sparkConf, bootStrapServers) + KafkaTokenUtil.findMatchingTokenClusterConfig(sparkConf, bootStrapServers) } assert(thrown.getMessage.contains("More than one delegation token matches")) } test("getTokenJaasParams with token should return scram module") { - addTokenToUGI(tokenService1) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) val clusterConf = createClusterConf(identifier1, SASL_SSL.name) val jaasParams = KafkaTokenUtil.getTokenJaasParams(clusterConf) assert(jaasParams.contains("ScramLoginModule required")) assert(jaasParams.contains("tokenauth=true")) - assert(jaasParams.contains(tokenId)) - 
assert(jaasParams.contains(tokenPassword)) + assert(jaasParams.contains(tokenId1)) + assert(jaasParams.contains(tokenPassword1)) + } + + test("isConnectorUsingCurrentToken without security should return true") { + val kafkaParams = Map[String, Object]().asJava + + assert(KafkaTokenUtil.isConnectorUsingCurrentToken(kafkaParams, None)) + } + + test("isConnectorUsingCurrentToken with same token should return true") { + setSparkEnv( + Map( + s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers + ) + ) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) + val kafkaParams = getKafkaParams() + val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(SparkEnv.get.conf, + kafkaParams.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG).asInstanceOf[String]) + + assert(KafkaTokenUtil.isConnectorUsingCurrentToken(kafkaParams, clusterConfig)) + } + + test("isConnectorUsingCurrentToken with different token should return false") { + setSparkEnv( + Map( + s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers + ) + ) + addTokenToUGI(tokenService1, tokenId1, tokenPassword1) + val kafkaParams = getKafkaParams() + addTokenToUGI(tokenService1, tokenId2, tokenPassword2) + val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(SparkEnv.get.conf, + kafkaParams.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG).asInstanceOf[String]) + + assert(!KafkaTokenUtil.isConnectorUsingCurrentToken(kafkaParams, clusterConfig)) + } + + private def getKafkaParams(): ju.Map[String, Object] = { + val clusterConf = createClusterConf(identifier1, SASL_SSL.name) + Map[String, Object]( + CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG -> bootStrapServers, + SaslConfigs.SASL_JAAS_CONFIG -> KafkaTokenUtil.getTokenJaasParams(clusterConf) + ).asJava } } diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 397de87d3cdff..d11569d709b23 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -45,6 +45,13 @@ ${project.version} provided + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-core_${scala.binary.version} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 4d3e476e7cc58..925327d9d58e6 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.streaming.kafka010 import java.io.File -import java.lang.{ Long => JLong } -import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID } +import java.lang.{Long => JLong} +import java.util.{Arrays, HashMap => JHashMap, Map => JMap, UUID} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -31,13 +31,12 @@ import scala.util.Random import org.apache.kafka.clients.consumer._ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.serialization.StringDeserializer -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import 
org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} +import org.apache.spark.streaming.{LocalStreamingContext, Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.scheduler.rate.RateEstimator @@ -45,8 +44,7 @@ import org.apache.spark.util.Utils class DirectKafkaStreamSuite extends SparkFunSuite - with BeforeAndAfter - with BeforeAndAfterAll + with LocalStreamingContext with Eventually with Logging { val sparkConf = new SparkConf() @@ -56,18 +54,17 @@ class DirectKafkaStreamSuite // Otherwise the poll timeout defaults to 2 minutes and causes test cases to run longer. .set("spark.streaming.kafka.consumer.poll.ms", "10000") - private var ssc: StreamingContext = _ private var testDir: File = _ private var kafkaTestUtils: KafkaTestUtils = _ - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } - override def afterAll() { + override def afterAll(): Unit = { try { if (kafkaTestUtils != null) { kafkaTestUtils.teardown() @@ -78,12 +75,13 @@ class DirectKafkaStreamSuite } } - after { - if (ssc != null) { - ssc.stop(stopSparkContext = true) - } - if (testDir != null) { - Utils.deleteRecursively(testDir) + override def afterEach(): Unit = { + try { + if (testDir != null) { + Utils.deleteRecursively(testDir) + } + } finally { + super.afterEach() } } @@ -342,7 +340,7 @@ class DirectKafkaStreamSuite val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") // Send data to Kafka - def sendData(data: Seq[Int]) { + def sendData(data: Seq[Int]): Unit = { val strings = data.map { _.toString} kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) } @@ -434,7 +432,7 @@ class DirectKafkaStreamSuite val committed = new ConcurrentHashMap[TopicPartition, OffsetAndMetadata]() // Send data to Kafka and wait for it to be received - def sendDataAndWaitForReceive(data: Seq[Int]) { + def sendDataAndWaitForReceive(data: Seq[Int]): Unit = { val strings = data.map { _.toString} kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) eventually(timeout(10.seconds), interval(50.milliseconds)) { diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala index 431473e7f1d38..82913cf416a5f 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala @@ -27,7 +27,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.serialization.ByteArrayDeserializer import org.mockito.Mockito.when import org.scalatest.BeforeAndAfterAll -import org.scalatest.mockito.MockitoSugar +import org.scalatestplus.mockito.MockitoSugar import org.apache.spark._ diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index 47bc8fec2c80c..d6123e16dd238 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -47,14 +47,14 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { 
private var sc: SparkContext = _ - override def beforeAll { + override def beforeAll: Unit = { super.beforeAll() sc = new SparkContext(sparkConf) kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } - override def afterAll { + override def afterAll: Unit = { try { try { if (sc != null) { @@ -81,7 +81,8 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private val preferredHosts = LocationStrategies.PreferConsistent - private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) { + private def compactLogs(topic: String, partition: Int, + messages: Array[(String, String)]): Unit = { val mockTime = new MockTime() val logs = new Pool[TopicPartition, Log]() val logDir = kafkaTestUtils.brokerLogDir diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 5dec9709011e6..999870acfb532 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -316,7 +316,7 @@ private[kafka010] class KafkaTestUtils extends Logging { val actualPort = factory.getLocalPort - def shutdown() { + def shutdown(): Unit = { factory.shutdown() // The directories are not closed even if the ZooKeeper server is shut down. // Please see ZOOKEEPER-1844, which is fixed in 3.4.6+. It leads to test failures diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala index dedd691cd1b23..d38ed9fc9263d 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala @@ -45,7 +45,7 @@ private[kafka010] class MockTime(@volatile private var currentMs: Long) extends override def nanoseconds: Long = TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS) - override def sleep(ms: Long) { + override def sleep(ms: Long): Unit = { this.currentMs += ms scheduler.tick() } diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index 86c42df9e8435..31ca2fe5c95ff 100644 --- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -32,13 +32,14 @@ import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.apache.spark.streaming.kinesis.KinesisUtils; +import org.apache.spark.streaming.kinesis.KinesisInitialPositions; +import org.apache.spark.streaming.kinesis.KinesisInputDStream; import scala.Tuple2; +import scala.reflect.ClassTag$; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.kinesis.AmazonKinesisClient; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; /** * Consumes messages from a Amazon Kinesis streams and does wordcount. 
@@ -135,11 +136,19 @@ public static void main(String[] args) throws Exception {
     // Create the Kinesis DStreams
     List<JavaDStream<byte[]>> streamsList = new ArrayList<>(numStreams);
     for (int i = 0; i < numStreams; i++) {
-      streamsList.add(
-          KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
-              InitialPositionInStream.LATEST, kinesisCheckpointInterval,
-              StorageLevel.MEMORY_AND_DISK_2())
-      );
+      streamsList.add(JavaDStream.fromDStream(
+          KinesisInputDStream.builder()
+              .streamingContext(jssc)
+              .checkpointAppName(kinesisAppName)
+              .streamName(streamName)
+              .endpointUrl(endpointUrl)
+              .regionName(regionName)
+              .initialPosition(new KinesisInitialPositions.Latest())
+              .checkpointInterval(kinesisCheckpointInterval)
+              .storageLevel(StorageLevel.MEMORY_AND_DISK_2())
+              .build(),
+          ClassTag$.MODULE$.apply(byte[].class)
+      ));
     }
 
     // Union all the streams if there is more than 1 stream
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
index fcb790e3ea1f9..a5d5ac769b28d 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
@@ -73,7 +73,7 @@ import org.apache.spark.streaming.kinesis.KinesisInputDStream
  * the Kinesis Spark Streaming integration.
  */
 object KinesisWordCountASL extends Logging {
-  def main(args: Array[String]) {
+  def main(args: Array[String]): Unit = {
     // Check that all required args were passed in.
     if (args.length != 3) {
       System.err.println(
@@ -178,7 +178,7 @@ object KinesisWordCountASL extends Logging {
 *   https://kinesis.us-east-1.amazonaws.com us-east-1 10 5
  */
 object KinesisWordProducerASL {
-  def main(args: Array[String]) {
+  def main(args: Array[String]): Unit = {
     if (args.length != 4) {
       System.err.println(
         """
@@ -269,7 +269,7 @@ object KinesisWordProducerASL {
  */
 private[streaming] object StreamingExamples extends Logging {
   // Set reasonable logging levels for streaming if the user has not configured log4j.
-  def setStreamingLogLevels() {
+  def setStreamingLogLevels(): Unit = {
     val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements
     if (!log4jInitialized) {
       // We first log something to initialize Spark's default logging, then we override the
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
index 5fb83b26f8382..11e949536f2b6 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
@@ -68,7 +68,7 @@ private[kinesis] class KinesisCheckpointer(
     if (checkpointer != null) {
       try {
         // We must call `checkpoint()` with no parameter to finish reading shards.
- // See an URL below for details: + // See a URL below for details: // https://forums.aws.amazon.com/thread.jspa?threadID=244218 KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) } catch { diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 608da0b8bf563..8c3931a1c87fd 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -19,7 +19,9 @@ package org.apache.spark.streaming.kinesis import scala.reflect.ClassTag -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import collection.JavaConverters._ +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration} +import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel import com.amazonaws.services.kinesis.model.Record import org.apache.spark.rdd.RDD @@ -43,7 +45,9 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( val messageHandler: Record => T, val kinesisCreds: SparkAWSCredentials, val dynamoDBCreds: Option[SparkAWSCredentials], - val cloudWatchCreds: Option[SparkAWSCredentials] + val cloudWatchCreds: Option[SparkAWSCredentials], + val metricsLevel: MetricsLevel, + val metricsEnabledDimensions: Set[String] ) extends ReceiverInputDStream[T](_ssc) { import KinesisReadConfigurations._ @@ -79,7 +83,8 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( override def getReceiver(): Receiver[T] = { new KinesisReceiver(streamName, endpointUrl, regionName, initialPosition, checkpointAppName, checkpointInterval, _storageLevel, messageHandler, - kinesisCreds, dynamoDBCreds, cloudWatchCreds) + kinesisCreds, dynamoDBCreds, cloudWatchCreds, + metricsLevel, metricsEnabledDimensions) } } @@ -104,6 +109,8 @@ object KinesisInputDStream { private var kinesisCredsProvider: Option[SparkAWSCredentials] = None private var dynamoDBCredsProvider: Option[SparkAWSCredentials] = None private var cloudWatchCredsProvider: Option[SparkAWSCredentials] = None + private var metricsLevel: Option[MetricsLevel] = None + private var metricsEnabledDimensions: Option[Set[String]] = None /** * Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a @@ -237,6 +244,7 @@ object KinesisInputDStream { * endpoint. Defaults to [[DefaultCredentialsProvider]] if no custom value is specified. * * @param credentials [[SparkAWSCredentials]] to use for Kinesis authentication + * @return Reference to this [[KinesisInputDStream.Builder]] */ def kinesisCredentials(credentials: SparkAWSCredentials): Builder = { kinesisCredsProvider = Option(credentials) @@ -248,6 +256,7 @@ object KinesisInputDStream { * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. * * @param credentials [[SparkAWSCredentials]] to use for DynamoDB authentication + * @return Reference to this [[KinesisInputDStream.Builder]] */ def dynamoDBCredentials(credentials: SparkAWSCredentials): Builder = { dynamoDBCredsProvider = Option(credentials) @@ -259,12 +268,43 @@ object KinesisInputDStream { * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. 
* * @param credentials [[SparkAWSCredentials]] to use for CloudWatch authentication + * @return Reference to this [[KinesisInputDStream.Builder]] */ def cloudWatchCredentials(credentials: SparkAWSCredentials): Builder = { cloudWatchCredsProvider = Option(credentials) this } + /** + * Sets the CloudWatch metrics level. Defaults to + * [[KinesisClientLibConfiguration.DEFAULT_METRICS_LEVEL]] if no custom value is specified. + * + * @param metricsLevel [[MetricsLevel]] to specify the CloudWatch metrics level + * @return Reference to this [[KinesisInputDStream.Builder]] + * @see + * [[https://docs.aws.amazon.com/streams/latest/dev/monitoring-with-kcl.html#metric-levels]] + */ + def metricsLevel(metricsLevel: MetricsLevel): Builder = { + this.metricsLevel = Option(metricsLevel) + this + } + + /** + * Sets the enabled CloudWatch metrics dimensions. Defaults to + * [[KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS]] + * if no custom value is specified. + * + * @param metricsEnabledDimensions Set[String] to specify which CloudWatch metrics dimensions + * should be enabled + * @return Reference to this [[KinesisInputDStream.Builder]] + * @see + * [[https://docs.aws.amazon.com/streams/latest/dev/monitoring-with-kcl.html#metric-levels]] + */ + def metricsEnabledDimensions(metricsEnabledDimensions: Set[String]): Builder = { + this.metricsEnabledDimensions = Option(metricsEnabledDimensions) + this + } + /** * Create a new instance of [[KinesisInputDStream]] with configured parameters and the provided * message handler. @@ -287,7 +327,9 @@ object KinesisInputDStream { ssc.sc.clean(handler), kinesisCredsProvider.getOrElse(DefaultCredentials), dynamoDBCredsProvider, - cloudWatchCredsProvider) + cloudWatchCredsProvider, + metricsLevel.getOrElse(DEFAULT_METRICS_LEVEL), + metricsEnabledDimensions.getOrElse(DEFAULT_METRICS_ENABLED_DIMENSIONS)) } /** @@ -324,4 +366,8 @@ object KinesisInputDStream { private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1" private[kinesis] val DEFAULT_INITIAL_POSITION: KinesisInitialPosition = new Latest() private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2 + private[kinesis] val DEFAULT_METRICS_LEVEL: MetricsLevel = + KinesisClientLibConfiguration.DEFAULT_METRICS_LEVEL + private[kinesis] val DEFAULT_METRICS_ENABLED_DIMENSIONS: Set[String] = + KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS.asScala.toSet } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 69c52365b1bf8..6feb8f1b5598f 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -25,6 +25,7 @@ import scala.util.control.NonFatal import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{KinesisClientLibConfiguration, Worker} +import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel import com.amazonaws.services.kinesis.model.Record import org.apache.spark.internal.Logging @@ -92,7 +93,9 @@ private[kinesis] class KinesisReceiver[T]( messageHandler: Record => T, kinesisCreds: SparkAWSCredentials, dynamoDBCreds: Option[SparkAWSCredentials], - cloudWatchCreds: 
Option[SparkAWSCredentials]) + cloudWatchCreds: Option[SparkAWSCredentials], + metricsLevel: MetricsLevel, + metricsEnabledDimensions: Set[String]) extends Receiver[T](storageLevel) with Logging { receiver => /* @@ -143,7 +146,7 @@ private[kinesis] class KinesisReceiver[T]( * This is called when the KinesisReceiver starts and must be non-blocking. * The KCL creates and manages the receiving/processing thread pool through Worker.run(). */ - override def onStart() { + override def onStart(): Unit = { blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) workerId = Utils.localHostName() + ":" + UUID.randomUUID() @@ -162,6 +165,8 @@ private[kinesis] class KinesisReceiver[T]( .withKinesisEndpoint(endpointUrl) .withTaskBackoffTimeMillis(500) .withRegionName(regionName) + .withMetricsLevel(metricsLevel) + .withMetricsEnabledDimensions(metricsEnabledDimensions.asJava) // Update the Kinesis client lib config with timestamp // if InitialPositionInStream.AT_TIMESTAMP is passed @@ -211,7 +216,7 @@ private[kinesis] class KinesisReceiver[T]( * The KCL worker.shutdown() method stops the receiving/processing threads. * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. */ - override def onStop() { + override def onStop(): Unit = { if (workerThread != null) { if (worker != null) { worker.shutdown() diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 8c6a399dd763e..b35573e92e168 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -51,7 +51,7 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w * * @param shardId assigned by the KCL to this particular RecordProcessor. */ - override def initialize(shardId: String) { + override def initialize(shardId: String): Unit = { this.shardId = shardId logInfo(s"Initialized workerId $workerId with shardId $shardId") } @@ -65,7 +65,8 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w * @param checkpointer used to update Kinesis when this batch has been processed/stored * in the DStream */ - override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { + override def processRecords(batch: List[Record], + checkpointer: IRecordProcessorCheckpointer): Unit = { if (!receiver.isStopped()) { try { // Limit the number of processed records from Kinesis stream. This is because the KCL cannot diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala deleted file mode 100644 index c60b9896a3473..0000000000000 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ /dev/null @@ -1,632 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
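Two new knobs now flow from the builder down to the KCL worker: `metricsLevel` and `metricsEnabledDimensions`, which default to `KinesisClientLibConfiguration.DEFAULT_METRICS_LEVEL` and `DEFAULT_METRICS_ENABLED_DIMENSIONS` and are forwarded by the receiver via `withMetricsLevel` / `withMetricsEnabledDimensions`. A hedged sketch of dialing CloudWatch reporting down from the caller side, reusing only constants that already appear in this patch (`configured` stands in for a builder that already has its stream, region and checkpoint settings):

```scala
import scala.collection.JavaConverters._

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration
import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel

import org.apache.spark.streaming.kinesis.KinesisInputDStream

// Sketch only: `configured` is assumed to be a fully configured builder.
def quietMetrics(configured: KinesisInputDStream.Builder): KinesisInputDStream.Builder =
  configured
    .metricsLevel(MetricsLevel.NONE)  // publish no KCL metrics at all
    .metricsEnabledDimensions(
      // keep only the dimensions the KCL always enables anyway
      KinesisClientLibConfiguration.METRICS_ALWAYS_ENABLED_DIMENSIONS.asScala.toSet)
```

Leaving both options unset keeps the previous behaviour, since the build step falls back to the KCL defaults captured in `DEFAULT_METRICS_LEVEL` and `DEFAULT_METRICS_ENABLED_DIMENSIONS`.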
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import scala.reflect.ClassTag - -import com.amazonaws.regions.RegionUtils -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream -import com.amazonaws.services.kinesis.model.Record - -import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Duration, StreamingContext} -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -object KinesisUtils { - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * @param ssc StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * @param messageHandler A custom message handler that can generate a generic output from a - * Kinesis `Record`, which contains both message data, and metadata. - * - * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. 
- */ - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream[T: ClassTag]( - ssc: StreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: Record => T): ReceiverInputDStream[T] = { - val cleanedHandler = ssc.sc.clean(messageHandler) - // Setting scope to override receiver stream's scope of "receiver stream" - ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), - kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, DefaultCredentials, None, None) - } - } - - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * @param ssc StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * @param messageHandler A custom message handler that can generate a generic output from a - * Kinesis `Record`, which contains both message data, and metadata. - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * - * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. 
- */ - // scalastyle:off - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream[T: ClassTag]( - ssc: StreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: Record => T, - awsAccessKeyId: String, - awsSecretKey: String): ReceiverInputDStream[T] = { - // scalastyle:on - val cleanedHandler = ssc.sc.clean(messageHandler) - ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentials( - awsAccessKeyId = awsAccessKeyId, - awsSecretKey = awsSecretKey) - new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), - kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider, None, None) - } - } - - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * @param ssc StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * @param messageHandler A custom message handler that can generate a generic output from a - * Kinesis `Record`, which contains both message data, and metadata. - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from - * Kinesis stream. - * @param stsSessionName Name to uniquely identify STS sessions if multiple principals assume - * the same role. - * @param stsExternalId External ID that can be used to validate against the assumed IAM role's - * trust policy. - * - * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. 
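Callers of the removed key-pair overloads move the access and secret key onto the builder through `kinesisCredentials`. A short migration sketch reusing the `BasicCredentials` case class exactly as the updated KinesisStreamSuite does; the case class appears to be package-private, so the sketch sits in the kinesis package, and both keys are placeholders:

```scala
package org.apache.spark.streaming.kinesis  // BasicCredentials is package-private here

import org.apache.spark.streaming.dstream.ReceiverInputDStream

object BasicCredsMigration {  // illustrative object name
  // Formerly: KinesisUtils.createStream(..., awsAccessKeyId, awsSecretKey)
  def withKeyPair(configured: KinesisInputDStream.Builder): ReceiverInputDStream[Array[Byte]] =
    configured
      .kinesisCredentials(BasicCredentials("placeholderAccessKeyId", "placeholderSecretKey"))
      .build()
}
```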
- */ - // scalastyle:off - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream[T: ClassTag]( - ssc: StreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: Record => T, - awsAccessKeyId: String, - awsSecretKey: String, - stsAssumeRoleArn: String, - stsSessionName: String, - stsExternalId: String): ReceiverInputDStream[T] = { - // scalastyle:on - val cleanedHandler = ssc.sc.clean(messageHandler) - ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = STSCredentials( - stsRoleArn = stsAssumeRoleArn, - stsSessionName = stsSessionName, - stsExternalId = Option(stsExternalId), - longLivedCreds = BasicCredentials( - awsAccessKeyId = awsAccessKeyId, - awsSecretKey = awsSecretKey)) - new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), - kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider, None, None) - } - } - - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * @param ssc StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * - * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. 
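For the STS overload, the long-lived key pair becomes the `longLivedCreds` of an `STSCredentials` value handed to the same `kinesisCredentials` call. A sketch using the named parameters from the removed implementation (role ARN, session name, external ID and keys are placeholders):

```scala
package org.apache.spark.streaming.kinesis  // STSCredentials is package-private here

object StsCredsMigration {  // illustrative object name
  def assumedRoleCreds(): SparkAWSCredentials =
    STSCredentials(
      stsRoleArn = "arn:aws:iam::123456789012:role/placeholder-role",
      stsSessionName = "placeholder-session",
      stsExternalId = Some("placeholder-external-id"),
      longLivedCreds = BasicCredentials("placeholderAccessKeyId", "placeholderSecretKey"))

  // Then, on a configured builder: configured.kinesisCredentials(assumedRoleCreds()).build()
}
```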
- */ - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream( - ssc: StreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { - // Setting scope to override receiver stream's scope of "receiver stream" - ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), - KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), - kinesisAppName, checkpointInterval, storageLevel, - KinesisInputDStream.defaultMessageHandler, DefaultCredentials, None, None) - } - } - - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * @param ssc StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * - * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - */ - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream( - ssc: StreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel, - awsAccessKeyId: String, - awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { - ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentials( - awsAccessKeyId = awsAccessKeyId, - awsSecretKey = awsSecretKey) - new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), - KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream), - kinesisAppName, checkpointInterval, storageLevel, - KinesisInputDStream.defaultMessageHandler, kinesisCredsProvider, None, None) - } - } - - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. 
- * - * @param jssc Java StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * @param messageHandler A custom message handler that can generate a generic output from a - * Kinesis `Record`, which contains both message data, and metadata. - * @param recordClass Class of the records in DStream - * - * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - */ - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream[T]( - jssc: JavaStreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: JFunction[Record, T], - recordClass: Class[T]): JavaReceiverInputDStream[T] = { - implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) - val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) - createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler) - } - - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * @param jssc Java StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. 
- * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * @param messageHandler A custom message handler that can generate a generic output from a - * Kinesis `Record`, which contains both message data, and metadata. - * @param recordClass Class of the records in DStream - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * - * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - */ - // scalastyle:off - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream[T]( - jssc: JavaStreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: JFunction[Record, T], - recordClass: Class[T], - awsAccessKeyId: String, - awsSecretKey: String): JavaReceiverInputDStream[T] = { - // scalastyle:on - implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) - val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) - createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler, - awsAccessKeyId, awsSecretKey) - } - - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * @param jssc Java StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * @param messageHandler A custom message handler that can generate a generic output from a - * Kinesis `Record`, which contains both message data, and metadata. - * @param recordClass Class of the records in DStream - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from - * Kinesis stream. - * @param stsSessionName Name to uniquely identify STS sessions if multiple princpals assume - * the same role. - * @param stsExternalId External ID that can be used to validate against the assumed IAM role's - * trust policy. 
- * - * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - */ - // scalastyle:off - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream[T]( - jssc: JavaStreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: JFunction[Record, T], - recordClass: Class[T], - awsAccessKeyId: String, - awsSecretKey: String, - stsAssumeRoleArn: String, - stsSessionName: String, - stsExternalId: String): JavaReceiverInputDStream[T] = { - // scalastyle:on - implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) - val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) - createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler, - awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId) - } - - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * @param jssc Java StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * - * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - */ - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream( - jssc: JavaStreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[Array[Byte]] = { - createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, - KinesisInputDStream.defaultMessageHandler(_)) - } - - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. 
- * - * @param jssc Java StreamingContext object - * @param kinesisAppName Kinesis application name used by the Kinesis Client Library - * (KCL) to update DynamoDB - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Name of region used by the Kinesis Client Library (KCL) to update - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param storageLevel Storage level to use for storing the received objects. - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * - * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - */ - @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") - def createStream( - jssc: JavaStreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointInterval: Duration, - storageLevel: StorageLevel, - awsAccessKeyId: String, - awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = { - createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, - KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) - } - - private def validateRegion(regionName: String): String = { - Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse { - throw new IllegalArgumentException(s"Region name '$regionName' is not valid") - } - } -} - -/** - * This is a helper class that wraps the methods in KinesisUtils into more Python-friendly class and - * function so that it can be easily instantiated and called from Python's KinesisUtils. - */ -private class KinesisUtilsPythonHelper { - - def getInitialPositionInStream(initialPositionInStream: Int): InitialPositionInStream = { - initialPositionInStream match { - case 0 => InitialPositionInStream.LATEST - case 1 => InitialPositionInStream.TRIM_HORIZON - case _ => throw new IllegalArgumentException( - "Illegal InitialPositionInStream. 
Please use " + - "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON") - } - } - - // scalastyle:off - def createStream( - jssc: JavaStreamingContext, - kinesisAppName: String, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: Int, - checkpointInterval: Duration, - storageLevel: StorageLevel, - awsAccessKeyId: String, - awsSecretKey: String, - stsAssumeRoleArn: String, - stsSessionName: String, - stsExternalId: String): JavaReceiverInputDStream[Array[Byte]] = { - // scalastyle:on - if (!(stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) - && !(stsAssumeRoleArn == null && stsSessionName == null && stsExternalId == null)) { - throw new IllegalArgumentException("stsAssumeRoleArn, stsSessionName, and stsExtenalId " + - "must all be defined or all be null") - } - - if (stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) { - validateAwsCreds(awsAccessKeyId, awsSecretKey) - KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, - KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, - stsAssumeRoleArn, stsSessionName, stsExternalId) - } else { - validateAwsCreds(awsAccessKeyId, awsSecretKey) - if (awsAccessKeyId == null && awsSecretKey == null) { - KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, - getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) - } else { - KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, - getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, - awsAccessKeyId, awsSecretKey) - } - } - } - - // Throw IllegalArgumentException unless both values are null or neither are. - private def validateAwsCreds(awsAccessKeyId: String, awsSecretKey: String) { - if (awsAccessKeyId == null && awsSecretKey != null) { - throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") - } - if (awsAccessKeyId != null && awsSecretKey == null) { - throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") - } - } -} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala new file mode 100644 index 0000000000000..c89dedd3366d1 --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.streaming.kinesis + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} + +/** + * This is a helper class that wraps the methods in KinesisUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's KinesisUtils. + */ +private class KinesisUtilsPythonHelper { + + // scalastyle:off + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: Int, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): JavaReceiverInputDStream[Array[Byte]] = { + // scalastyle:on + if (!(stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) + && !(stsAssumeRoleArn == null && stsSessionName == null && stsExternalId == null)) { + throw new IllegalArgumentException("stsAssumeRoleArn, stsSessionName, and stsExtenalId " + + "must all be defined or all be null") + } + if (awsAccessKeyId == null && awsSecretKey != null) { + throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") + } + if (awsAccessKeyId != null && awsSecretKey == null) { + throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") + } + + val kinesisInitialPosition = initialPositionInStream match { + case 0 => InitialPositionInStream.LATEST + case 1 => InitialPositionInStream.TRIM_HORIZON + case _ => throw new IllegalArgumentException( + "Illegal InitialPositionInStream. Please use " + + "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON") + } + + val builder = KinesisInputDStream.builder. + streamingContext(jssc). + checkpointAppName(kinesisAppName). + streamName(streamName). + endpointUrl(endpointUrl). + regionName(regionName). + initialPosition(KinesisInitialPositions.fromKinesisInitialPosition(kinesisInitialPosition)). + checkpointInterval(checkpointInterval). + storageLevel(storageLevel) + + if (stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) { + val kinesisCredsProvider = STSCredentials( + stsAssumeRoleArn, stsSessionName, Option(stsExternalId), + BasicCredentials(awsAccessKeyId, awsSecretKey)) + builder. + kinesisCredentials(kinesisCredsProvider). + buildWithMessageHandler(KinesisInputDStream.defaultMessageHandler) + } else { + if (awsAccessKeyId == null && awsSecretKey == null) { + builder.build() + } else { + builder.kinesisCredentials(BasicCredentials(awsAccessKeyId, awsSecretKey)).build() + } + } + } + +} diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java deleted file mode 100644 index b37b087467926..0000000000000 --- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
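In the relocated Python helper the STS branch goes through `buildWithMessageHandler(KinesisInputDStream.defaultMessageHandler)` while the other branches call `build()`; `buildWithMessageHandler` is also what replaces the removed `createStream` overloads that accepted a custom `Record` handler. A hedged sketch of such a handler, mirroring the partition-key/sequence-number handler from the deleted Java test (the return type of `buildWithMessageHandler` is assumed here to be assignable to `ReceiverInputDStream`):

```scala
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.kinesis.KinesisInputDStream

// Sketch only: `configured` is assumed to be a fully configured builder.
def keyedStream(configured: KinesisInputDStream.Builder): ReceiverInputDStream[String] =
  configured.buildWithMessageHandler { record: Record =>
    // Same shape as the handler in the removed JavaKinesisStreamSuite
    record.getPartitionKey + "-" + record.getSequenceNumber
  }
```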
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.kinesis; - -import com.amazonaws.services.kinesis.model.Record; -import org.junit.Test; - -import org.apache.spark.api.java.function.Function; -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.LocalJavaStreamingContext; -import org.apache.spark.streaming.api.java.JavaDStream; - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; - -/** - * Demonstrate the use of the KinesisUtils Java API - */ -public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { - @Test - public void testKinesisStream() { - String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl(); - String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl); - - // Tests the API, does not actually test data receiving - JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", - dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000), - StorageLevel.MEMORY_AND_DISK_2()); - ssc.stop(); - } - - @Test - public void testAwsCreds() { - String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl(); - String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl); - - // Tests the API, does not actually test data receiving - JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", - dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000), - StorageLevel.MEMORY_AND_DISK_2(), "fakeAccessKey", "fakeSecretKey"); - ssc.stop(); - } - - private static Function handler = new Function() { - @Override - public String call(Record record) { - return record.getPartitionKey() + "-" + record.getSequenceNumber(); - } - }; - - @Test - public void testCustomHandler() { - // Tests the API, does not actually test data receiving - JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, - new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class); - - ssc.stop(); - } - - @Test - public void testCustomHandlerAwsCreds() { - // Tests the API, does not actually test data receiving - JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, - new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class, - "fakeAccessKey", "fakeSecretKey"); - - ssc.stop(); - } - - @Test - public void testCustomHandlerAwsStsCreds() { - // Tests the API, does not actually test data receiving - JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, - new Duration(2000), 
StorageLevel.MEMORY_AND_DISK_2(), handler, String.class, - "fakeAccessKey", "fakeSecretKey", "fakeSTSRoleArn", "fakeSTSSessionName", - "fakeSTSExternalId"); - - ssc.stop(); - } -} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala index ac0e6a8429d06..3e88e956ec237 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -28,7 +28,7 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually -import org.scalatest.mockito.MockitoSugar +import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.streaming.{Duration, TestSuiteBase} import org.apache.spark.util.ManualClock diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala index 1c81298a7c201..8dc4de1aa3609 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -27,7 +27,7 @@ trait KinesisFunSuite extends SparkFunSuite { import KinesisTestUtils._ /** Run the test if environment variable is set or ignore the test */ - def testIfEnabled(testName: String)(testBody: => Unit) { + def testIfEnabled(testName: String)(testBody: => Unit): Unit = { if (shouldRunTests) { test(testName)(testBody) } else { diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala index 361520e292266..8b0d73c96da73 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.streaming.kinesis import java.util.Calendar -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import collection.JavaConverters._ +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration} +import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel import org.scalatest.BeforeAndAfterEach -import org.scalatest.mockito.MockitoSugar +import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Duration, Seconds, StreamingContext, TestSuiteBase} @@ -82,6 +84,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE assert(dstream.kinesisCreds == DefaultCredentials) assert(dstream.dynamoDBCreds == None) assert(dstream.cloudWatchCreds == None) + assert(dstream.metricsLevel == DEFAULT_METRICS_LEVEL) + assert(dstream.metricsEnabledDimensions == DEFAULT_METRICS_ENABLED_DIMENSIONS) } test("should propagate custom non-auth values to KinesisInputDStream") { @@ -94,6 +98,9 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE val 
customKinesisCreds = mock[SparkAWSCredentials] val customDynamoDBCreds = mock[SparkAWSCredentials] val customCloudWatchCreds = mock[SparkAWSCredentials] + val customMetricsLevel = MetricsLevel.NONE + val customMetricsEnabledDimensions = + KinesisClientLibConfiguration.METRICS_ALWAYS_ENABLED_DIMENSIONS.asScala.toSet val dstream = builder .endpointUrl(customEndpointUrl) @@ -105,6 +112,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE .kinesisCredentials(customKinesisCreds) .dynamoDBCredentials(customDynamoDBCreds) .cloudWatchCredentials(customCloudWatchCreds) + .metricsLevel(customMetricsLevel) + .metricsEnabledDimensions(customMetricsEnabledDimensions) .build() assert(dstream.endpointUrl == customEndpointUrl) assert(dstream.regionName == customRegion) @@ -115,6 +124,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE assert(dstream.kinesisCreds == customKinesisCreds) assert(dstream.dynamoDBCreds == Option(customDynamoDBCreds)) assert(dstream.cloudWatchCreds == Option(customCloudWatchCreds)) + assert(dstream.metricsLevel == customMetricsLevel) + assert(dstream.metricsEnabledDimensions == customMetricsEnabledDimensions) // Testing with AtTimestamp val cal = Calendar.getInstance() @@ -132,6 +143,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE .kinesisCredentials(customKinesisCreds) .dynamoDBCredentials(customDynamoDBCreds) .cloudWatchCredentials(customCloudWatchCreds) + .metricsLevel(customMetricsLevel) + .metricsEnabledDimensions(customMetricsEnabledDimensions) .build() assert(dstreamAtTimestamp.endpointUrl == customEndpointUrl) assert(dstreamAtTimestamp.regionName == customRegion) @@ -145,6 +158,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE assert(dstreamAtTimestamp.kinesisCreds == customKinesisCreds) assert(dstreamAtTimestamp.dynamoDBCreds == Option(customDynamoDBCreds)) assert(dstreamAtTimestamp.cloudWatchCreds == Option(customCloudWatchCreds)) + assert(dstreamAtTimestamp.metricsLevel == customMetricsLevel) + assert(dstreamAtTimestamp.metricsEnabledDimensions == customMetricsEnabledDimensions) } test("old Api should throw UnsupportedOperationExceptionexception with AT_TIMESTAMP") { diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 52690847418ef..470a8cecc8fd9 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -27,7 +27,7 @@ import com.amazonaws.services.kinesis.model.Record import org.mockito.ArgumentMatchers.{anyList, anyString, eq => meq} import org.mockito.Mockito.{never, times, verify, when} import org.scalatest.{BeforeAndAfter, Matchers} -import org.scalatest.mockito.MockitoSugar +import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.streaming.{Duration, TestSuiteBase} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 51ee7fd213de5..eee62d25e62bb 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ 
b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import scala.concurrent.duration._ import scala.util.Random -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.Matchers._ @@ -31,7 +30,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming._ +import org.apache.spark.streaming.{LocalStreamingContext, _} import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest import org.apache.spark.streaming.kinesis.KinesisReadConfigurations._ @@ -41,7 +40,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.util.Utils abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite - with Eventually with BeforeAndAfter with BeforeAndAfterAll { + with LocalStreamingContext with Eventually with BeforeAndAfter with BeforeAndAfterAll { // This is the name that KCL will use to save metadata to DynamoDB private val appName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}" @@ -54,15 +53,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun private val dummyAWSSecretKey = "dummySecretKey" private var testUtils: KinesisTestUtils = null - private var ssc: StreamingContext = null private var sc: SparkContext = null override def beforeAll(): Unit = { - val conf = new SparkConf() - .setMaster("local[4]") - .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name - sc = new SparkContext(conf) - runIfTestsEnabled("Prepare KinesisTestUtils") { testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() @@ -71,12 +64,6 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun override def afterAll(): Unit = { try { - if (ssc != null) { - ssc.stop() - } - if (sc != null) { - sc.stop() - } if (testUtils != null) { // Delete the Kinesis stream as well as the DynamoDB table generated by // Kinesis Client Library when consuming the stream @@ -88,34 +75,36 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } } - before { + override def beforeEach(): Unit = { + super.beforeEach() + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name + sc = new SparkContext(conf) ssc = new StreamingContext(sc, batchDuration) } - after { - if (ssc != null) { - ssc.stop(stopSparkContext = false) - ssc = null - } - if (testUtils != null) { - testUtils.deleteDynamoDBTable(appName) + override def afterEach(): Unit = { + try { + if (testUtils != null) { + testUtils.deleteDynamoDBTable(appName) + } + } finally { + super.afterEach() } } - test("KinesisUtils API") { - val kinesisStream1 = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", - dummyEndpointUrl, dummyRegionName, - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", - dummyEndpointUrl, dummyRegionName, - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, - dummyAWSAccessKey, 
dummyAWSSecretKey) - } - test("RDD generation") { - val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream", - dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2), - StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey) + val inputStream = KinesisInputDStream.builder. + streamingContext(ssc). + checkpointAppName(appName). + streamName("dummyStream"). + endpointUrl(dummyEndpointUrl). + regionName(dummyRegionName).initialPosition(new Latest()). + checkpointInterval(Seconds(2)). + storageLevel(StorageLevel.MEMORY_AND_DISK_2). + kinesisCredentials(BasicCredentials(dummyAWSAccessKey, dummyAWSSecretKey)). + build() assert(inputStream.isInstanceOf[KinesisInputDStream[Array[Byte]]]) val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream[Array[Byte]]] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala b/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala index ecc37dcaad1fe..d733868908350 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala @@ -81,13 +81,13 @@ object Edge { override def copyElement( src: Array[Edge[ED]], srcPos: Int, - dst: Array[Edge[ED]], dstPos: Int) { + dst: Array[Edge[ED]], dstPos: Int): Unit = { dst(dstPos) = src(srcPos) } override def copyRange( src: Array[Edge[ED]], srcPos: Int, - dst: Array[Edge[ED]], dstPos: Int, length: Int) { + dst: Array[Edge[ED]], dstPos: Int, length: Int): Unit = { System.arraycopy(src, srcPos, dst, dstPos, length) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala index ef0b943fc3c38..4ff5b02daecbe 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala @@ -30,7 +30,7 @@ object GraphXUtils { /** * Registers classes that GraphX uses with Kryo. */ - def registerKryoClasses(conf: SparkConf) { + def registerKryoClasses(conf: SparkConf): Unit = { conf.registerKryoClasses(Array( classOf[Edge[Object]], classOf[(VertexId, Object)], @@ -54,7 +54,7 @@ object GraphXUtils { mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], reduceFunc: (A, A) => A, activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = { - def sendMsg(ctx: EdgeContext[VD, ED, A]) { + def sendMsg(ctx: EdgeContext[VD, ED, A]): Unit = { mapFunc(ctx.toEdgeTriplet).foreach { kv => val id = kv._1 val msg = kv._2 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 0e6a340a680ba..8d03112a1c3dc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -222,7 +222,7 @@ class EdgePartition[ * * @param f an external state mutating user defined function. 
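
For reference, a minimal sketch of how a consumer of the KinesisInputDStream builder exercised above might set the new metricsLevel and metricsEnabledDimensions options. Stream, region and endpoint values are placeholders, and the import paths are assumed from the KCL 1.x client library and the spark-streaming-kinesis-asl module; only the builder methods themselves are taken from the hunks above.

import scala.collection.JavaConverters._
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration
import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest
import org.apache.spark.streaming.kinesis.KinesisInputDStream

def buildStream(ssc: StreamingContext) = KinesisInputDStream.builder
  .streamingContext(ssc)
  .checkpointAppName("myKinesisApp")           // placeholder application name
  .streamName("myKinesisStream")               // placeholder stream name
  .endpointUrl("https://kinesis.us-west-2.amazonaws.com")
  .regionName("us-west-2")
  .initialPosition(new Latest())
  .checkpointInterval(Seconds(2))
  .storageLevel(StorageLevel.MEMORY_AND_DISK_2)
  .metricsLevel(MetricsLevel.NONE)             // new option: turn CloudWatch metrics off
  .metricsEnabledDimensions(
    KinesisClientLibConfiguration.METRICS_ALWAYS_ENABLED_DIMENSIONS.asScala.toSet)
  .build()
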
*/ - def foreach(f: Edge[ED] => Unit) { + def foreach(f: Edge[ED] => Unit): Unit = { iterator.foreach(f) } @@ -495,7 +495,7 @@ private class AggregatingEdgeContext[VD, ED, A]( srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, srcAttr: VD, dstAttr: VD, - attr: ED) { + attr: ED): Unit = { _srcId = srcId _dstId = dstId _localSrcId = localSrcId @@ -505,13 +505,13 @@ private class AggregatingEdgeContext[VD, ED, A]( _attr = attr } - def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD) { + def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD): Unit = { _srcId = srcId _localSrcId = localSrcId _srcAttr = srcAttr } - def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) { + def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED): Unit = { _dstId = dstId _localDstId = localDstId _dstAttr = dstAttr @@ -524,14 +524,14 @@ private class AggregatingEdgeContext[VD, ED, A]( override def dstAttr: VD = _dstAttr override def attr: ED = _attr - override def sendToSrc(msg: A) { + override def sendToSrc(msg: A): Unit = { send(_localSrcId, msg) } - override def sendToDst(msg: A) { + override def sendToDst(msg: A): Unit = { send(_localDstId, msg) } - @inline private def send(localId: Int, msg: A) { + @inline private def send(localId: Int, msg: A): Unit = { if (bitset.get(localId)) { aggregates(localId) = mergeMsg(aggregates(localId), msg) } else { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 27c08c894a39f..c7868f85d1f76 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -30,7 +30,7 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla private[this] val edges = new PrimitiveVector[Edge[ED]](size) /** Add a new edge to the partition. */ - def add(src: VertexId, dst: VertexId, d: ED) { + def add(src: VertexId, dst: VertexId, d: ED): Unit = { edges += Edge(src, dst, d) } @@ -90,7 +90,7 @@ class ExistingEdgePartitionBuilder[ private[this] val edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size) /** Add a new edge to the partition. 
*/ - def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) { + def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED): Unit = { edges += EdgeWithLocalIds(src, dst, localSrc, localDst, d) } @@ -153,13 +153,13 @@ private[impl] object EdgeWithLocalIds { override def copyElement( src: Array[EdgeWithLocalIds[ED]], srcPos: Int, - dst: Array[EdgeWithLocalIds[ED]], dstPos: Int) { + dst: Array[EdgeWithLocalIds[ED]], dstPos: Int): Unit = { dst(dstPos) = src(srcPos) } override def copyRange( src: Array[EdgeWithLocalIds[ED]], srcPos: Int, - dst: Array[EdgeWithLocalIds[ED]], dstPos: Int, length: Int) { + dst: Array[EdgeWithLocalIds[ED]], dstPos: Int, length: Int): Unit = { System.arraycopy(src, srcPos, dst, dstPos, length) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 0a97ab492600d..8564597f4f135 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -103,15 +103,16 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( (part, (e.srcId, e.dstId, e.attr)) } .partitionBy(new HashPartitioner(numPartitions)) - .mapPartitionsWithIndex( { (pid, iter) => - val builder = new EdgePartitionBuilder[ED, VD]()(edTag, vdTag) - iter.foreach { message => - val data = message._2 - builder.add(data._1, data._2, data._3) - } - val edgePartition = builder.toEdgePartition - Iterator((pid, edgePartition)) - }, preservesPartitioning = true)).cache() + .mapPartitionsWithIndex( + { (pid: Int, iter: Iterator[(PartitionID, (VertexId, VertexId, ED))]) => + val builder = new EdgePartitionBuilder[ED, VD]()(edTag, vdTag) + iter.foreach { message => + val data = message._2 + builder.add(data._1, data._2, data._3) + } + val edgePartition = builder.toEdgePartition + Iterator((pid, edgePartition)) + }, preservesPartitioning = true)).cache() GraphImpl.fromExistingRDDs(vertices.withEdges(newEdges), newEdges) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index d2194d85bf525..e0d4dd3248734 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -58,7 +58,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * `vertices`. This operation modifies the `ReplicatedVertexView`, and callers can access `edges` * afterwards to obtain the upgraded view. 
*/ - def upgrade(vertices: VertexRDD[VD], includeSrc: Boolean, includeDst: Boolean) { + def upgrade(vertices: VertexRDD[VD], includeSrc: Boolean, includeDst: Boolean): Unit = { val shipSrc = includeSrc && !hasSrcId val shipDst = includeDst && !hasDstId if (shipSrc || shipDst) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 6453bbeae9f10..bef380dc12c23 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -123,7 +123,7 @@ class RoutingTablePartition( */ def foreachWithinEdgePartition (pid: PartitionID, includeSrc: Boolean, includeDst: Boolean) - (f: VertexId => Unit) { + (f: VertexId => Unit): Unit = { val (vidsCandidate, srcVids, dstVids) = routingTable(pid) val size = vidsCandidate.length if (includeSrc && includeDst) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 2847a4e172d40..c508056fe3ae3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -98,7 +98,7 @@ object SVDPlusPlus { (ctx: EdgeContext[ (Array[Double], Array[Double], Double, Double), Double, - (Array[Double], Array[Double], Double)]) { + (Array[Double], Array[Double], Double)]): Unit = { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) val rank = p.length @@ -177,7 +177,7 @@ object SVDPlusPlus { // calculate error on training set def sendMsgTestF(conf: Conf, u: Double) - (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]) { + (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]): Unit = { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + blas.ddot(q.length, q, 1, usr._2, 1) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index 2715137d19ebc..211b4d6e4c5d3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -85,7 +85,7 @@ object TriangleCount { } // Edge function computes intersection of smaller vertex with larger vertex - def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) { + def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]): Unit = { val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) { (ctx.srcAttr, ctx.dstAttr) } else { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index 5ece5ae5c359b..dc3cdc452a389 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -118,7 +118,7 @@ private[graphx] object BytecodeUtils { if (name == methodName) { new MethodVisitor(ASM7) { override def visitMethodInsn( - op: Int, owner: String, name: String, desc: String, itf: Boolean) { + op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { 
methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index 972237da1cb28..e3b283649cb2c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -71,7 +71,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } /** Set the value for a key */ - def update(k: K, v: V) { + def update(k: K, v: V): Unit = { val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK _values(pos) = v keySet.rehashIfNeeded(k, grow, move) @@ -80,7 +80,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, /** Set the value for a key */ - def setMerge(k: K, v: V, mergeF: (V, V) => V) { + def setMerge(k: K, v: V, mergeF: (V, V) => V): Unit = { val pos = keySet.addWithoutResize(k) val ind = pos & OpenHashSet.POSITION_MASK if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index 84940d96b563f..32844104c1deb 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -26,8 +26,11 @@ import java.util.Map; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Level; +import java.util.logging.Logger; import static org.apache.spark.launcher.CommandBuilderUtils.*; +import static org.apache.spark.launcher.CommandBuilderUtils.join; /** * Launcher for Spark applications. @@ -38,6 +41,8 @@ */ public class SparkLauncher extends AbstractLauncher { + private static final Logger LOG = Logger.getLogger(SparkLauncher.class.getName()); + /** The Spark master. */ public static final String SPARK_MASTER = "spark.master"; @@ -363,6 +368,9 @@ public SparkAppHandle startApplication(SparkAppHandle.Listener... 
listeners) thr String loggerName = getLoggerName(); ProcessBuilder pb = createBuilder(); + if (LOG.isLoggable(Level.FINE)) { + LOG.fine(String.format("Launching Spark application:%n%s", join(" ", pb.command()))); + } boolean outputToLog = outputStream == null; boolean errorToLog = !redirectErrorStream && errorStream == null; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 3479e0c3422bd..383c3f60a595b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -348,7 +348,7 @@ private List buildPySparkShellCommand(Map env) throws IO } private List buildSparkRCommand(Map env) throws IOException { - if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { + if (!appArgs.isEmpty() && (appArgs.get(0).endsWith(".R") || appArgs.get(0).endsWith(".r"))) { System.err.println( "Running R applications through 'sparkR' is not supported as of Spark 2.0.\n" + "Use ./bin/spark-submit "); @@ -390,9 +390,7 @@ boolean isClientMode(Map userProps) { String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER)); String userDeployMode = firstNonEmpty(deployMode, userProps.get(SparkLauncher.DEPLOY_MODE)); // Default master is "local[*]", so assume client mode in that case - return userMaster == null || - "client".equals(userDeployMode) || - (!userMaster.equals("yarn-cluster") && userDeployMode == null); + return userMaster == null || userDeployMode == null || "client".equals(userDeployMode); } /** diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 32a91b1789412..752e8d4c23f8b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -250,6 +250,26 @@ public void testMissingAppResource() { new SparkSubmitCommandBuilder().buildSparkSubmitArgs(); } + @Test + public void testIsClientMode() { + // Default master is "local[*]" + SparkSubmitCommandBuilder builder = newCommandBuilder(Collections.emptyList()); + assertTrue("By default application run in local mode", + builder.isClientMode(Collections.emptyMap())); + // --master yarn or it can be any RM + List sparkSubmitArgs = Arrays.asList(parser.MASTER, "yarn"); + builder = newCommandBuilder(sparkSubmitArgs); + assertTrue("By default deploy mode is client", builder.isClientMode(Collections.emptyMap())); + // --master yarn and set spark.submit.deployMode to client + Map userProps = new HashMap<>(); + userProps.put("spark.submit.deployMode", "client"); + assertTrue(builder.isClientMode(userProps)); + // --master mesos --deploy-mode cluster + sparkSubmitArgs = Arrays.asList(parser.MASTER, "mesos", parser.DEPLOY_MODE, "cluster"); + builder = newCommandBuilder(sparkSubmitArgs); + assertFalse(builder.isClientMode(Collections.emptyMap())); + } + private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) throws Exception { final String DRIVER_DEFAULT_PARAM = "-Ddriver-default"; final String DRIVER_EXTRA_PARAM = "-Ddriver-extra"; diff --git a/licenses-binary/LICENSE-JLargeArrays.txt b/licenses-binary/LICENSE-JLargeArrays.txt new file mode 100644 index 0000000000000..304e724556984 --- /dev/null +++ 
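
For reference, the simplified isClientMode rule above amounts to: client mode unless an explicit master is paired with an explicit non-client deploy mode. A minimal standalone sketch of that predicate, with the cases covered by the new testIsClientMode (names are illustrative and not part of the launcher API; the asserts are as one might check in a REPL):

def isClientMode(master: Option[String], deployMode: Option[String]): Boolean =
  master.isEmpty || deployMode.isEmpty || deployMode.contains("client")

assert(isClientMode(None, None))                    // default master local[*]
assert(isClientMode(Some("yarn"), None))            // deploy mode defaults to client
assert(isClientMode(Some("yarn"), Some("client")))
assert(!isClientMode(Some("mesos"), Some("cluster")))
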
b/licenses-binary/LICENSE-JLargeArrays.txt @@ -0,0 +1,23 @@ +JLargeArrays +Copyright (C) 2013 onward University of Warsaw, ICM +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-JTransforms.txt b/licenses-binary/LICENSE-JTransforms.txt new file mode 100644 index 0000000000000..2f0589f76da7d --- /dev/null +++ b/licenses-binary/LICENSE-JTransforms.txt @@ -0,0 +1,23 @@ +JTransforms +Copyright (c) 2007 onward, Piotr Wendykier +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-dnsjava.txt b/licenses-binary/LICENSE-dnsjava.txt new file mode 100644 index 0000000000000..70ee5b12ff23f --- /dev/null +++ b/licenses-binary/LICENSE-dnsjava.txt @@ -0,0 +1,24 @@ +Copyright (c) 1998-2011, Brian Wellington. +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-jtransforms.html b/licenses-binary/LICENSE-jtransforms.html deleted file mode 100644 index 351c17412357b..0000000000000 --- a/licenses-binary/LICENSE-jtransforms.html +++ /dev/null @@ -1,388 +0,0 @@ - - -Mozilla Public License version 1.1 - - - - -

[The deleted licenses-binary/LICENSE-jtransforms.html (388 lines of HTML) contained the full text of the Mozilla Public License Version 1.1, Sections 1 through 13 plus Exhibit A, naming Piotr Wendykier, Emory University as the Initial Developer of JTransforms and offering the GPL Version 2 or later and the LGPL Version 2.1 or later as alternative licenses. It is replaced in this patch by the BSD-style licenses-binary/LICENSE-JTransforms.txt added above.]
\ No newline at end of file diff --git a/licenses-binary/LICENSE-re2j.txt b/licenses-binary/LICENSE-re2j.txt new file mode 100644 index 0000000000000..0dc3cd70bf1f7 --- /dev/null +++ b/licenses-binary/LICENSE-re2j.txt @@ -0,0 +1,32 @@ +This is a work derived from Russ Cox's RE2 in Go, whose license +http://golang.org/LICENSE is as follows: + +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + + * Neither the name of Google Inc. nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 2a0f8c11d0a50..e054a15fc9b75 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -302,7 +302,7 @@ private[spark] object BLAS extends Serializable { * @param x the vector x that contains the n elements. * @param A the symmetric matrix A. Size of n x n. */ - def syr(alpha: Double, x: Vector, A: DenseMatrix) { + def syr(alpha: Double, x: Vector, A: DenseMatrix): Unit = { val mA = A.numRows val nA = A.numCols require(mA == nA, s"A is not a square matrix (and hence is not symmetric). 
A: $mA x $nA") @@ -316,7 +316,7 @@ private[spark] object BLAS extends Serializable { } } - private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + private def syr(alpha: Double, x: DenseVector, A: DenseMatrix): Unit = { val nA = A.numRows val mA = A.numCols @@ -334,7 +334,7 @@ private[spark] object BLAS extends Serializable { } } - private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { + private def syr(alpha: Double, x: SparseVector, A: DenseMatrix): Unit = { val mA = A.numCols val xIndices = x.indices val xValues = x.values diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 6e43d60bd03a3..f437d66cddb54 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -178,6 +178,14 @@ sealed trait Vector extends Serializable { */ @Since("2.0.0") def argmax: Int + + /** + * Calculate the dot product of this vector with another. + * + * If `size` does not match an [[IllegalArgumentException]] is thrown. + */ + @Since("3.0.0") + def dot(v: Vector): Double = BLAS.dot(this, v) } /** diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala index 332734bd28341..7d29d6dcea908 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala @@ -21,7 +21,7 @@ import java.util.Random import breeze.linalg.{CSCMatrix, Matrix => BM} import org.mockito.Mockito.when -import org.scalatest.mockito.MockitoSugar._ +import org.scalatestplus.mockito.MockitoSugar._ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.ml.SparkMLFunSuite diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 0a316f57f811b..c97dc2c3c06f8 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -380,4 +380,27 @@ class VectorsSuite extends SparkMLFunSuite { Vectors.sparse(-1, Array((1, 2.0))) } } + + test("dot product only supports vectors of same size") { + val vSize4 = Vectors.dense(arr) + val vSize1 = Vectors.zeros(1) + intercept[IllegalArgumentException]{ vSize1.dot(vSize4) } + } + + test("dense vector dot product") { + val dv = Vectors.dense(arr) + assert(dv.dot(dv) === 0.26) + } + + test("sparse vector dot product") { + val sv = Vectors.sparse(n, indices, values) + assert(sv.dot(sv) === 0.26) + } + + test("mixed sparse and dense vector dot product") { + val sv = Vectors.sparse(n, indices, values) + val dv = Vectors.dense(arr) + assert(sv.dot(dv) === 0.26) + assert(dv.dot(sv) === 0.26) + } } diff --git a/mllib/benchmarks/UDTSerializationBenchmark-jdk11-results.txt b/mllib/benchmarks/UDTSerializationBenchmark-jdk11-results.txt new file mode 100644 index 0000000000000..6f671405b8343 --- /dev/null +++ b/mllib/benchmarks/UDTSerializationBenchmark-jdk11-results.txt @@ -0,0 +1,12 @@ +================================================================================================ +VectorUDT de/serialization +================================================================================================ + +OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 
3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +VectorUDT de/serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +serialize 269 292 13 0.0 269441.1 1.0X +deserialize 164 191 9 0.0 164314.6 1.6X + + diff --git a/mllib/benchmarks/UDTSerializationBenchmark-results.txt b/mllib/benchmarks/UDTSerializationBenchmark-results.txt index 169f4c60c748e..a0c853e99014b 100644 --- a/mllib/benchmarks/UDTSerializationBenchmark-results.txt +++ b/mllib/benchmarks/UDTSerializationBenchmark-results.txt @@ -2,12 +2,11 @@ VectorUDT de/serialization ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.13.6 -Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz - -VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -serialize 144 / 206 0.0 143979.7 1.0X -deserialize 114 / 135 0.0 113802.6 1.3X +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +VectorUDT de/serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +serialize 271 294 12 0.0 271054.3 1.0X +deserialize 190 192 2 0.0 189706.1 1.4X diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 58815434cbdaf..9eac8ed22a3f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -62,6 +62,39 @@ private[ml] trait PredictorParams extends Params } SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) } + + /** + * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + */ + protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = { + val w = this match { + case p: HasWeightCol => + if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) { + col($(p.weightCol)).cast(DoubleType) + } else { + lit(1.0) + } + } + + dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + } + + /** + * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + * Validate the output instances with the given function. 
+ */ + protected def extractInstances(dataset: Dataset[_], + validateInstance: Instance => Unit): RDD[Instance] = { + extractInstances(dataset).map { instance => + validateInstance(instance) + instance + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index b6b02e77909bd..9ac673078d4ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils} @@ -42,6 +42,22 @@ private[spark] trait ClassifierParams val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT) } + + /** + * Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + * Validates the label on the classifier is a valid integer in the range [0, numClasses). + */ + protected def extractInstances(dataset: Dataset[_], + numClasses: Int): RDD[Instance] = { + val validateInstance = (instance: Instance) => { + val label = instance.label + require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" + + s" dataset with invalid label $label. 
Labels must be integers in range" + + s" [0, $numClasses).") + } + extractInstances(dataset, validateInstance) + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 6bd8a26f5b1a8..2d0212f36fad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -22,7 +22,7 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since -import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ @@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.functions.{col, udf} /** * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning) @@ -116,9 +115,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr => instr.logPipelineStage(this) instr.logDataset(dataset) - val categoricalFeatures: Map[Int, Int] = - MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val numClasses: Int = getNumClasses(dataset) + val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses = getNumClasses(dataset) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -126,13 +124,7 @@ class DecisionTreeClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } validateNumClasses(numClasses) - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - validateLabel(label, numClasses) - Instance(label, weight, features) - } + val instances = extractInstances(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) instr.logNumClasses(numClasses) instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol, diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 78503585261bf..e467228b4cc14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -36,9 +36,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions.{col, lit} /** Params for linear SVM Classifier. */ private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam @@ -161,12 +159,7 @@ class LinearSVC @Since("2.2.0") ( override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra) override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr => - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances = extractInstances(dataset) instr.logPipelineStage(this) instr.logDataset(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 0997c1e7b38d6..af6e2b39ecb60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -40,9 +40,8 @@ import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, Multiclas import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{DataType, DoubleType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils @@ -492,12 +491,7 @@ class LogisticRegression @Since("1.2.0") ( protected[spark] def train( dataset: Dataset[_], handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr => - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances = extractInstances(dataset) if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 47b8a8df637b9..41db6f3f44342 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} import org.apache.spark.ml.feature.OneHotEncoderModel -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index e97af0582d358..205f565aa2685 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol @@ -28,7 +29,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.functions.col /** * Params for Naive Bayes Classifiers. @@ -137,17 +138,14 @@ class NaiveBayes @Since("1.5.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val modelTypeValue = $(modelType) - val requireValues: Vector => Unit = { - modelTypeValue match { - case Multinomial => - requireNonnegativeValues - case Bernoulli => - requireZeroOneBernoulliValues - case _ => - // This should never happen. - throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") - } + val validateInstance = $(modelType) match { + case Multinomial => + (instance: Instance) => requireNonnegativeValues(instance.features) + case Bernoulli => + (instance: Instance) => requireZeroOneBernoulliValues(instance.features) + case _ => + // This should never happen. + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, @@ -155,17 +153,15 @@ class NaiveBayes @Since("1.5.0") ( val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size instr.logNumFeatures(numFeatures) - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) // Aggregates term frequencies per label. // TODO: Calling aggregateByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. 
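The refactor below keys each instance by its label and folds (weight, features) pairs into per-label sums. A minimal, self-contained sketch of the same aggregation pattern, written against public APIs only; the input RDD of (label, weight, features) triples and numFeatures are assumed inputs, not part of this change.

    import org.apache.spark.ml.linalg.Vector
    import org.apache.spark.rdd.RDD

    def aggregateByLabel(data: RDD[(Double, Double, Vector)], numFeatures: Int) = {
      data
        .map { case (label, weight, features) => (label, (weight, features)) }
        .aggregateByKey[(Double, Array[Double], Long)](
          (0.0, Array.fill(numFeatures)(0.0), 0L))(
          seqOp = { case ((wSum, fSum, cnt), (w, f)) =>
            // accumulate the weighted feature vector into fSum
            f.foreachActive { (i, v) => fSum(i) += w * v }
            (wSum + w, fSum, cnt + 1L)
          },
          combOp = { case ((w1, f1, c1), (w2, f2, c2)) =>
            var i = 0
            while (i < f1.length) { f1(i) += f2(i); i += 1 }
            (w1 + w2, f1, c1 + c2)
          })
        .collectAsMap()
    }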
- val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd - .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) - }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))( + val aggregated = extractInstances(dataset, validateInstance).map { instance => + (instance.label, (instance.weight, instance.features)) + }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))( seqOp = { case ((weightSum, featureSum, count), (weight, features)) => - requireValues(features) BLAS.axpy(weight, features, featureSum) (weightSum + weight, featureSum, count + 1) }, diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 86caa1247e77f..979eb5e5448a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -33,8 +33,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix, Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -111,28 +111,32 @@ class GaussianMixtureModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - var predictionColNames = Seq.empty[String] - var predictionColumns = Seq.empty[Column] - - if ($(predictionCol).nonEmpty) { - val predUDF = udf((vector: Vector) => predict(vector)) - predictionColNames :+= $(predictionCol) - predictionColumns :+= predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)) - } + val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol)) + var outputData = dataset + var numColsOutput = 0 if ($(probabilityCol).nonEmpty) { val probUDF = udf((vector: Vector) => predictProbability(vector)) - predictionColNames :+= $(probabilityCol) - predictionColumns :+= probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)) + outputData = outputData.withColumn($(probabilityCol), probUDF(vectorCol)) + numColsOutput += 1 + } + + if ($(predictionCol).nonEmpty) { + if ($(probabilityCol).nonEmpty) { + val predUDF = udf((vector: Vector) => vector.argmax) + outputData = outputData.withColumn($(predictionCol), predUDF(col($(probabilityCol)))) + } else { + val predUDF = udf((vector: Vector) => predict(vector)) + outputData = outputData.withColumn($(predictionCol), predUDF(vectorCol)) + } + numColsOutput += 1 } - if (predictionColNames.nonEmpty) { - dataset.withColumns(predictionColNames, predictionColumns) - } else { + if (numColsOutput == 0) { this.logWarning(s"$uid: GaussianMixtureModel.transform() does nothing" + " because no output columns were set.") - dataset.toDF() } + outputData.toDF } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 2a7b3c579b078..09e8e7b232f3a 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -59,6 +59,28 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.2.0") def setMetricName(value: String): this.type = set(metricName, value) + /** + * param for number of bins to down-sample the curves (ROC curve, PR curve) in area + * computation. If 0, no down-sampling will occur. + * Default: 1000. + * @group expertParam + */ + @Since("3.0.0") + val numBins: IntParam = new IntParam(this, "numBins", "Number of bins to down-sample " + + "the curves (ROC curve, PR curve) in area computation. If 0, no down-sampling will occur. " + + "Must be >= 0.", + ParamValidators.gtEq(0)) + + /** @group expertGetParam */ + @Since("3.0.0") + def getNumBins: Int = $(numBins) + + /** @group expertSetParam */ + @Since("3.0.0") + def setNumBins(value: Int): this.type = set(numBins, value) + + setDefault(numBins -> 1000) + /** @group setParam */ @Since("1.5.0") def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) @@ -94,7 +116,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va case Row(rawPrediction: Double, label: Double, weight: Double) => (rawPrediction, label, weight) } - val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights) + val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, $(numBins)) val metric = $(metricName) match { case "areaUnderROC" => metrics.areaUnderROC() case "areaUnderPR" => metrics.areaUnderPR() @@ -104,10 +126,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va } @Since("1.5.0") - override def isLargerBetter: Boolean = $(metricName) match { - case "areaUnderROC" => true - case "areaUnderPR" => true - } + override def isLargerBetter: Boolean = true @Since("1.4.1") override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index dd667a85fa598..b0cafefe420a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since -import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics @@ -43,13 +43,14 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui * - `"mse"`: mean squared error * - `"r2"`: R^2^ metric * - `"mae"`: mean absolute error + * - `"var"`: explained variance * * @group param */ @Since("1.4.0") val metricName: Param[String] = { - val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) - new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams) + val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae", "var")) + new Param(this, "metricName", "metric name in evaluation 
(mse|rmse|r2|mae|var)", allowedParams) } /** @group getParam */ @@ -60,6 +61,25 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") def setMetricName(value: String): this.type = set(metricName, value) + /** + * param for whether the regression is through the origin. + * Default: false. + * @group expertParam + */ + @Since("3.0.0") + val throughOrigin: BooleanParam = new BooleanParam(this, "throughOrigin", + "Whether the regression is through the origin.") + + /** @group expertGetParam */ + @Since("3.0.0") + def getThroughOrigin: Boolean = $(throughOrigin) + + /** @group expertSetParam */ + @Since("3.0.0") + def setThroughOrigin(value: Boolean): this.type = set(throughOrigin, value) + + setDefault(throughOrigin -> false) + /** @group setParam */ @Since("1.4.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) @@ -86,22 +106,20 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui .rdd .map { case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } - val metrics = new RegressionMetrics(predictionAndLabelsWithWeights) - val metric = $(metricName) match { + val metrics = new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin)) + $(metricName) match { case "rmse" => metrics.rootMeanSquaredError case "mse" => metrics.meanSquaredError case "r2" => metrics.r2 case "mae" => metrics.meanAbsoluteError + case "var" => metrics.explainedVariance } - metric } @Since("1.4.0") override def isLargerBetter: Boolean = $(metricName) match { - case "rmse" => false - case "mse" => false - case "r2" => true - case "mae" => false + case "r2" | "var" => true + case _ => false } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 2b0862c60fdf7..c4daf64dfc5f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -75,30 +75,40 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) val schema = dataset.schema val inputType = schema($(inputCol)).dataType val td = $(threshold) + val metadata = outputSchema($(outputCol)).metadata - val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 } - val binarizerVector = udf { (data: Vector) => - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - - data.foreachActive { (index, value) => - if (value > td) { - indices += index - values += 1.0 + val binarizerUDF = inputType match { + case DoubleType => + udf { in: Double => if (in > td) 1.0 else 0.0 } + + case _: VectorUDT if td >= 0 => + udf { vector: Vector => + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + vector.foreachActive { (index, value) => + if (value > td) { + indices += index + values += 1.0 + } + } + Vectors.sparse(vector.size, indices.result(), values.result()).compressed } - } - Vectors.sparse(data.size, indices.result(), values.result()).compressed + case _: VectorUDT if td < 0 => + this.logWarning(s"Binarization operations on sparse dataset with negative threshold " + + s"$td will build a dense output, so take care when applying to sparse input.") + udf { vector: Vector => + val values = Array.fill(vector.size)(1.0) + vector.foreachActive { (index, value) => + if (value <= td) { + values(index) = 0.0 + } + } + Vectors.dense(values).compressed + } } - val 
metadata = outputSchema($(outputCol)).metadata - - inputType match { - case DoubleType => - dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata)) - case _: VectorUDT => - dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata)) - } + dataset.withColumn($(outputCol), binarizerUDF(col($(inputCol))), metadata) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 32d98151bdcff..84d6a536ccca8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import edu.emory.mathcs.jtransforms.dct._ +import org.jtransforms.dct._ import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 5bfaa3b7f3f52..f7a83cdd41a90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -167,25 +167,38 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("2.3.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - private[feature] def getInOutCols: (Array[String], Array[String]) = { - require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) || - (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)), - "QuantileDiscretizer only supports setting either inputCol/outputCol or" + - "inputCols/outputCols." - ) + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), + Seq(outputCols)) if (isSet(inputCol)) { - (Array($(inputCol)), Array($(outputCol))) - } else { - require($(inputCols).length == $(outputCols).length, - "inputCols number do not match outputCols") - ($(inputCols), $(outputCols)) + require(!isSet(numBucketsArray), + s"numBucketsArray can't be set for single-column QuantileDiscretizer.") } - } - @Since("1.6.0") - override def transformSchema(schema: StructType): StructType = { - val (inputColNames, outputColNames) = getInOutCols + if (isSet(inputCols)) { + require(getInputCols.length == getOutputCols.length, + s"QuantileDiscretizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols) should have " + + s"equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}).") + if (isSet(numBucketsArray)) { + require(getInputCols.length == getNumBucketsArray.length, + s"QuantileDiscretizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols, numBucketsArray) " + + s"should have equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}, ${getNumBucketsArray.length}).") + require(!isSet(numBuckets), + s"exactly one of numBuckets, numBucketsArray Params to be set, but both are set." 
) + } + } + + val (inputColNames, outputColNames) = if (isSet(inputCols)) { + ($(inputCols).toSeq, $(outputCols).toSeq) + } else { + (Seq($(inputCol)), Seq($(outputCol))) + } val existingFields = schema.fields var outputFields = existingFields inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 6c0d5fc70ab4e..df7d17059980b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -392,7 +392,7 @@ class RFormulaModel private[feature]( } } - private def checkCanTransform(schema: StructType) { + private def checkCanTransform(schema: StructType): Unit = { val columnNames = schema.map(_.name) require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index fb7334d41ba44..bf6e8ec8f37b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContext} +import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} @@ -42,7 +42,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -564,6 +564,13 @@ object ALSModel extends MLReadable[ALSModel] { * r is greater than 0 and 0 if r is less than or equal to 0. The ratings then act as 'confidence' * values related to strength of indicated user * preferences rather than explicit ratings given to items. + * + * Note: the input rating dataset to the ALS implementation should be deterministic. + * Nondeterministic data can cause failure during fitting ALS model. + * For example, an order-sensitive operation like sampling after a repartition makes dataset + * output nondeterministic, like `dataset.repartition(2).sample(false, 0.5, 1618)`. + * Checkpointing sampled dataset or adding a sort before sampling can help make the dataset + * deterministic. */ @Since("1.3.0") class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams @@ -794,7 +801,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square * matrix that it represents, storing it into destMatrix. 
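Tying the ALS determinism note above to code: a short sketch of the two suggested workarounds, sorting before the sample and checkpointing the sampled data. The ratings DataFrame, its column names, the SparkSession value, and the checkpoint directory are assumptions for illustration only.

    // Option 1: sort on a stable key after the repartition so the sampled
    // rows are reproducible across reruns.
    val sortedSample = ratings
      .repartition(2)
      .sortWithinPartitions("userId", "itemId")
      .sample(false, 0.5, 1618)

    // Option 2: checkpoint the sampled data so a rerun reuses the
    // materialized rows instead of recomputing a nondeterministic lineage.
    spark.sparkContext.setCheckpointDir("/tmp/als-checkpoint")  // assumed path
    val checkpointedSample = ratings
      .repartition(2)
      .sample(false, 0.5, 1618)
      .checkpoint()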
*/ - private def fillAtA(triAtA: Array[Double], lambda: Double) { + private def fillAtA(triAtA: Array[Double], lambda: Double): Unit = { var i = 0 var pos = 0 var a = 0.0 @@ -1666,6 +1673,13 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } } val merged = srcOut.groupByKey(new ALSPartitioner(dstInBlocks.partitions.length)) + + // SPARK-28927: Nondeterministic RDDs causes inconsistent in/out blocks in case of rerun. + // It can cause runtime error when matching in/out user/item blocks. + val isBlockRDDNondeterministic = + dstInBlocks.outputDeterministicLevel == DeterministicLevel.INDETERMINATE || + srcOutBlocks.outputDeterministicLevel == DeterministicLevel.INDETERMINATE + dstInBlocks.join(merged).mapValues { case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) @@ -1686,7 +1700,19 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { val encoded = srcEncodedIndices(i) val blockId = srcEncoder.blockId(encoded) val localIndex = srcEncoder.localIndex(encoded) - val srcFactor = sortedSrcFactors(blockId)(localIndex) + var srcFactor: Array[Float] = null + try { + srcFactor = sortedSrcFactors(blockId)(localIndex) + } catch { + case a: ArrayIndexOutOfBoundsException if isBlockRDDNondeterministic => + val errMsg = "A failure detected when matching In/Out blocks of users/items. " + + "Because at least one In/Out block RDD is found to be nondeterministic now, " + + "the issue is probably caused by nondeterministic input data. You can try to " + + "checkpoint training data to make it deterministic. If you do `repartition` + " + + "`sample` or `randomSplit`, you can also try to sort it before `sample` or " + + "`randomSplit` to make it deterministic." + throw new SparkException(errMsg, a) + } val rating = ratings(i) if (implicitPrefs) { // Extension to the original paper to handle rating < 0. 
confidence is a function diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 106be1b78af47..602b5fac20d3b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -23,7 +23,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ @@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType /** @@ -118,12 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances = extractInstances(dataset) val strategy = getOldStrategy(categoricalFeatures) instr.logPipelineStage(this) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index a226ca49e6deb..4dc0c247ce331 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1036,31 +1036,33 @@ class GeneralizedLinearRegressionModel private[ml] ( } override protected def transformImpl(dataset: Dataset[_]): DataFrame = { - var predictionColNames = Seq.empty[String] - var predictionColumns = Seq.empty[Column] - val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) + var outputData = dataset + var numColsOutput = 0 - if ($(predictionCol).nonEmpty) { - val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) } - predictionColNames :+= $(predictionCol) - predictionColumns :+= predictUDF(col($(featuresCol)), offset) + if (hasLinkPredictionCol) { + val predLinkUDF = udf((features: Vector, offset: Double) => predictLink(features, offset)) + outputData = outputData + .withColumn($(linkPredictionCol), predLinkUDF(col($(featuresCol)), offset)) + numColsOutput += 1 } - if (hasLinkPredictionCol) { - val predictLinkUDF = - udf { (features: Vector, offset: Double) => predictLink(features, offset) } - predictionColNames :+= $(linkPredictionCol) - predictionColumns :+= predictLinkUDF(col($(featuresCol)), offset) + if ($(predictionCol).nonEmpty) { + if 
(hasLinkPredictionCol) { + val predUDF = udf((eta: Double) => familyAndLink.fitted(eta)) + outputData = outputData.withColumn($(predictionCol), predUDF(col($(linkPredictionCol)))) + } else { + val predUDF = udf((features: Vector, offset: Double) => predict(features, offset)) + outputData = outputData.withColumn($(predictionCol), predUDF(col($(featuresCol)), offset)) + } + numColsOutput += 1 } - if (predictionColNames.nonEmpty) { - dataset.withColumns(predictionColNames, predictionColumns) - } else { + if (numColsOutput == 0) { this.logWarning(s"$uid: GeneralizedLinearRegressionModel.transform() does nothing" + " because no output columns were set.") - dataset.toDF() } + outputData.toDF } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index abf75d70ea028..4c600eac26b37 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -43,7 +43,6 @@ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -320,13 +319,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr => // Extract the number of features before deciding optimization solver. val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size - val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = dataset.select( - col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances = extractInstances(dataset) instr.logPipelineStage(this) instr.logDataset(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala index c0a1683d3cb6f..314cf422be87e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -28,7 +28,8 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} * @tparam Learner Concrete Estimator type * @tparam M Concrete Model type */ -private[spark] abstract class Regressor[ +@DeveloperApi +abstract class Regressor[ FeaturesType, Learner <: Regressor[FeaturesType, Learner, M], M <: RegressionModel[FeaturesType, M]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 8f8a17171f980..6c194902a750b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -90,7 +90,7 @@ private[spark] class DecisionTreeMetadata( * Set number of splits for a continuous feature. 
* For a continuous feature, number of bins is number of splits plus 1. */ - def setNumSplits(featureIndex: Int, numSplits: Int) { + def setNumSplits(featureIndex: Int, numSplits: Int): Unit = { require(isContinuous(featureIndex), s"Only number of bin for a continuous feature can be set.") numBins(featureIndex) = numSplits + 1 diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 8cd4a7ca9493b..58a763257af20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -205,21 +205,21 @@ private[spark] class OptionalInstrumentation private( protected override def logName: String = className - override def logInfo(msg: => String) { + override def logInfo(msg: => String): Unit = { instrumentation match { case Some(instr) => instr.logInfo(msg) case None => super.logInfo(msg) } } - override def logWarning(msg: => String) { + override def logWarning(msg: => String): Unit = { instrumentation match { case Some(instr) => instr.logWarning(msg) case None => super.logWarning(msg) } } - override def logError(msg: => String) { + override def logError(msg: => String): Unit = { instrumentation match { case Some(instr) => instr.logError(msg) case None => super.logError(msg) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 4617073f9decd..bafaafb720ed8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -347,7 +347,6 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[Vector], k: Int, maxIterations: Int, - runs: Int, initializationMode: String, seed: java.lang.Long, initializationSteps: Int, @@ -1312,7 +1311,7 @@ private[spark] abstract class SerDeBase { } } - private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler) + private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit } def dumps(obj: AnyRef): Array[Byte] = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index d86aa01c9195a..df888bc3d5d51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -224,117 +224,11 @@ class LogisticRegressionWithSGD private[mllib] ( .setMiniBatchFraction(miniBatchFraction) override protected val validators = List(DataValidators.binaryLabelValidator) - /** - * Construct a LogisticRegression object with default parameters: {stepSize: 1.0, - * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}. - */ - @Since("0.8.0") - @deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0") - def this() = this(1.0, 100, 0.01, 1.0) - override protected[mllib] def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) } } -/** - * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. 
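The deprecation notes on the removed LogisticRegressionWithSGD helpers point callers to LogisticRegressionWithLBFGS or ml.classification.LogisticRegression. A hedged migration sketch; the points RDD and the training DataFrame with "label"/"features" columns are assumed inputs.

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.DataFrame

    // RDD-based replacement: L-BFGS instead of SGD, same LabeledPoint input.
    def trainWithLBFGS(points: RDD[LabeledPoint]) =
      new LogisticRegressionWithLBFGS().setNumClasses(2).run(points)

    // DataFrame-based replacement in spark.ml.
    def trainWithML(training: DataFrame) =
      new LogisticRegression().setMaxIter(100).setRegParam(0.01).fit(training)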
- * - * @note Labels used in Logistic Regression should be {0, 1} - */ -@Since("0.8.0") -@deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0") -object LogisticRegressionWithSGD { - // NOTE(shivaram): We use multiple train methods instead of default arguments to support - // Java programs. - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - * - * @note Labels used in Logistic Regression should be {0, 1} - */ - @Since("1.0.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double, - initialWeights: Vector): LogisticRegressionModel = { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) - .run(input, initialWeights) - } - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. - * - * @note Labels used in Logistic Regression should be {0, 1} - */ - @Since("1.0.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double): LogisticRegressionModel = { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) - .run(input) - } - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using the specified step size. We use the entire data - * set to update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LogisticRegressionModel which has the weights and offset from training. - * - * @note Labels used in Logistic Regression should be {0, 1} - */ - @Since("1.0.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double): LogisticRegressionModel = { - train(input, numIterations, stepSize, 1.0) - } - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using a step size of 1.0. We use the entire data set - * to update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. 
- * @param numIterations Number of iterations of gradient descent to run. - * @return a LogisticRegressionModel which has the weights and offset from training. - * - * @note Labels used in Logistic Regression should be {0, 1} - */ - @Since("1.0.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int): LogisticRegressionModel = { - train(input, numIterations, 1.0, 1.0) - } -} - /** * Train a classification model for Multinomial/Binary Logistic Regression using * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 4bb79bc69eef4..278d61d916735 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -479,58 +479,6 @@ object KMeans { .run(data) } - /** - * Trains a k-means model using the given set of parameters. - * - * @param data Training points as an `RDD` of `Vector` types. - * @param k Number of clusters to create. - * @param maxIterations Maximum number of iterations allowed. - * @param runs This param has no effect since Spark 2.0.0. - * @param initializationMode The initialization algorithm. This can either be "random" or - * "k-means||". (default: "k-means||") - * @param seed Random seed for cluster initialization. Default is to generate seed based - * on system time. - */ - @Since("1.3.0") - @deprecated("Use train method without 'runs'", "2.1.0") - def train( - data: RDD[Vector], - k: Int, - maxIterations: Int, - runs: Int, - initializationMode: String, - seed: Long): KMeansModel = { - new KMeans().setK(k) - .setMaxIterations(maxIterations) - .setInitializationMode(initializationMode) - .setSeed(seed) - .run(data) - } - - /** - * Trains a k-means model using the given set of parameters. - * - * @param data Training points as an `RDD` of `Vector` types. - * @param k Number of clusters to create. - * @param maxIterations Maximum number of iterations allowed. - * @param runs This param has no effect since Spark 2.0.0. - * @param initializationMode The initialization algorithm. This can either be "random" or - * "k-means||". (default: "k-means||") - */ - @Since("0.8.0") - @deprecated("Use train method without 'runs'", "2.1.0") - def train( - data: RDD[Vector], - k: Int, - maxIterations: Int, - runs: Int, - initializationMode: String): KMeansModel = { - new KMeans().setK(k) - .setMaxIterations(maxIterations) - .setInitializationMode(initializationMode) - .run(data) - } - /** * Trains a k-means model using specified parameters and the default values for unspecified. */ @@ -544,21 +492,6 @@ object KMeans { .run(data) } - /** - * Trains a k-means model using specified parameters and the default values for unspecified. 
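Since the runs parameter has had no effect since Spark 2.0.0, callers of the removed overloads can simply drop that argument and use the remaining train variants; a brief sketch, where the data RDD is an assumed input.

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    // Before: KMeans.train(data, k, maxIterations, runs, initializationMode)
    // After: the same call without `runs`.
    def cluster(data: RDD[Vector]) =
      KMeans.train(data, k = 10, maxIterations = 20,
        initializationMode = KMeans.K_MEANS_PARALLEL)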
- */ - @Since("0.8.0") - @deprecated("Use train method without 'runs'", "2.1.0") - def train( - data: RDD[Vector], - k: Int, - maxIterations: Int, - runs: Int): KMeansModel = { - new KMeans().setK(k) - .setMaxIterations(maxIterations) - .run(data) - } - private[spark] def validateInitMode(initMode: String): Boolean = { initMode match { case KMeans.RANDOM => true diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index ff4ca0ac40fe2..c7d44e8752cd9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -269,7 +269,7 @@ class StreamingKMeans @Since("1.2.0") ( * @param data DStream containing vector data */ @Since("1.2.0") - def trainOn(data: DStream[Vector]) { + def trainOn(data: DStream[Vector]): Unit = { assertInitialized() data.foreachRDD { (rdd, time) => model = model.update(rdd, decayFactor, timeUnit) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index d34a7ca6c9c7f..f4e2040569f48 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -81,7 +81,7 @@ class BinaryClassificationMetrics @Since("3.0.0") ( * Unpersist intermediate RDDs used in the computation. */ @Since("1.0.0") - def unpersist() { + def unpersist(): Unit = { cumulativeCounts.unpersist() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 82f5b279846ba..b771e077b02ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -44,17 +44,6 @@ class ChiSqSelectorModel @Since("1.3.0") ( private val filterIndices = selectedFeatures.sorted - @deprecated("not intended for subclasses to use", "2.1.0") - protected def isSorted(array: Array[Int]): Boolean = { - var i = 1 - val len = array.length - while (i < len) { - if (array(i) < array(i-1)) return false - i += 1 - } - true - } - /** * Applies transformation on a vector. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index cb97742245689..1f5558dc2a50e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -285,7 +285,7 @@ private[spark] object BLAS extends Serializable with Logging { * @param x the vector x that contains the n elements. * @param A the symmetric matrix A. Size of n x n. */ - def syr(alpha: Double, x: Vector, A: DenseMatrix) { + def syr(alpha: Double, x: Vector, A: DenseMatrix): Unit = { val mA = A.numRows val nA = A.numCols require(mA == nA, s"A is not a square matrix (and hence is not symmetric). 
A: $mA x $nA") @@ -299,7 +299,7 @@ private[spark] object BLAS extends Serializable with Logging { } } - private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + private def syr(alpha: Double, x: DenseVector, A: DenseMatrix): Unit = { val nA = A.numRows val mA = A.numCols @@ -317,7 +317,7 @@ private[spark] object BLAS extends Serializable with Logging { } } - private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { + private def syr(alpha: Double, x: SparseVector, A: DenseMatrix): Unit = { val mA = A.numCols val xIndices = x.indices val xValues = x.values diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index e474cfa002fad..0304fd88dcd9f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -155,7 +155,7 @@ sealed trait Matrix extends Serializable { * and column indices respectively with the type `Int`, and the final parameter is the * corresponding value in the matrix with type `Double`. */ - private[spark] def foreachActive(f: (Int, Int, Double) => Unit) + private[spark] def foreachActive(f: (Int, Int, Double) => Unit): Unit /** * Find the number of non-zero active values. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index b754fad0c1796..83a519326df75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -204,6 +204,14 @@ sealed trait Vector extends Serializable { */ @Since("2.0.0") def asML: newlinalg.Vector + + /** + * Calculate the dot product of this vector with another. + * + * If `size` does not match an [[IllegalArgumentException]] is thrown. + */ + @Since("3.0.0") + def dot(v: Vector): Double = BLAS.dot(this, v) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 0d223de9b6f7e..f3b984948e483 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -153,7 +153,7 @@ class CoordinateMatrix @Since("1.0.0") ( } /** Determines the size by computing the max row/column index. */ - private def computeSize() { + private def computeSize(): Unit = { // Reduce will throw an exception if `entries` is empty. val (m1, n1) = entries.map(entry => (entry.i, entry.j)).reduce { case ((i1, j1), (i2, j2)) => (math.max(i1, i2), math.max(j1, j2)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 43f48befd014f..f25d86b30631a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -770,7 +770,7 @@ class RowMatrix @Since("1.0.0") ( } /** Updates or verifies the number of rows. 
*/ - private def updateNumRows(m: Long) { + private def updateNumRows(m: Long): Unit = { if (nRows <= 0) { nRows = m } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index fa04f8eb5e796..d3b548832bb21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -107,7 +107,7 @@ class PoissonGenerator @Since("1.1.0") ( override def nextValue(): Double = rng.sample() @Since("1.1.0") - override def setSeed(seed: Long) { + override def setSeed(seed: Long): Unit = { rng.reseedRandomGenerator(seed) } @@ -132,7 +132,7 @@ class ExponentialGenerator @Since("1.3.0") ( override def nextValue(): Double = rng.sample() @Since("1.3.0") - override def setSeed(seed: Long) { + override def setSeed(seed: Long): Unit = { rng.reseedRandomGenerator(seed) } @@ -159,7 +159,7 @@ class GammaGenerator @Since("1.3.0") ( override def nextValue(): Double = rng.sample() @Since("1.3.0") - override def setSeed(seed: Long) { + override def setSeed(seed: Long): Unit = { rng.reseedRandomGenerator(seed) } @@ -187,7 +187,7 @@ class LogNormalGenerator @Since("1.3.0") ( override def nextValue(): Double = rng.sample() @Since("1.3.0") - override def setSeed(seed: Long) { + override def setSeed(seed: Long): Unit = { rng.reseedRandomGenerator(seed) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 12870f819b147..f3f15ba0d0f2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -62,6 +62,13 @@ case class Rating @Since("0.8.0") ( * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of * indicated user * preferences rather than explicit ratings given to items. + * + * Note: the input rating RDD to the ALS implementation should be deterministic. + * Nondeterministic data can cause failure during fitting ALS model. + * For example, an order-sensitive operation like sampling after a repartition makes RDD + * output nondeterministic, like `rdd.repartition(2).sample(false, 0.5, 1618)`. + * Checkpointing sampled RDD or adding a sort before sampling can help make the RDD + * deterministic. */ @Since("0.8.0") class ALS private ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index ead9f5b300375..47bb1fa9127a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -24,7 +24,6 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel import org.apache.spark.mllib.util.{Loader, Saveable} -import org.apache.spark.rdd.RDD /** * Regression model trained using Lasso. @@ -99,117 +98,7 @@ class LassoWithSGD private[mllib] ( .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - /** - * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 0.01, miniBatchFraction: 1.0}. - */ - @Since("0.8.0") - @deprecated("Use ml.regression.LinearRegression with elasticNetParam = 1.0. 
Note the default " + - "regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", "2.0.0") - def this() = this(1.0, 100, 0.01, 1.0) - override protected def createModel(weights: Vector, intercept: Double) = { new LassoModel(weights, intercept) } } - -/** - * Top-level methods for calling Lasso. - * - */ -@Since("0.8.0") -@deprecated("Use ml.regression.LinearRegression with elasticNetParam = 1.0. Note the default " + - "regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", "2.0.0") -object LassoWithSGD { - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used - * in gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data - * matrix A as well as the corresponding right hand side label y - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size scaling to be used for the iterations of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - * - */ - @Since("1.0.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Vector): LassoModel = { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction) - .run(input, initialWeights) - } - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. - * - * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data - * matrix A as well as the corresponding right hand side label y - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * - */ - @Since("0.8.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double): LassoModel = { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) - } - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * update the true gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data - * matrix A as well as the corresponding right hand side label y - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param regParam Regularization parameter. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LassoModel which has the weights and offset from training. 
- * - */ - @Since("0.8.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double): LassoModel = { - train(input, numIterations, stepSize, regParam, 1.0) - } - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. We use the entire data set to - * compute the true gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data - * matrix A as well as the corresponding right hand side label y - * @param numIterations Number of iterations of gradient descent to run. - * @return a LassoModel which has the weights and offset from training. - * - */ - @Since("0.8.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int): LassoModel = { - train(input, numIterations, 1.0, 0.01, 1.0) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index cb08216fbf690..f68ebc17e294d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -24,7 +24,6 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel import org.apache.spark.mllib.util.{Loader, Saveable} -import org.apache.spark.rdd.RDD /** * Regression model trained using LinearRegression. @@ -100,109 +99,8 @@ class LinearRegressionWithSGD private[mllib] ( .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - /** - * Construct a LinearRegression object with default parameters: {stepSize: 1.0, - * numIterations: 100, miniBatchFraction: 1.0}. - */ - @Since("0.8.0") - @deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0") - def this() = this(1.0, 100, 0.0, 1.0) - override protected[mllib] def createModel(weights: Vector, intercept: Double) = { new LinearRegressionModel(weights, intercept) } } -/** - * Top-level methods for calling LinearRegression. - * - */ -@Since("0.8.0") -@deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0") -object LinearRegressionWithSGD { - - /** - * Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used - * in gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data - * matrix A as well as the corresponding right hand side label y - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. 
- * - */ - @Since("1.0.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double, - initialWeights: Vector): LinearRegressionModel = { - new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) - .run(input, initialWeights) - } - - /** - * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. - * - * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data - * matrix A as well as the corresponding right hand side label y - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. - * - */ - @Since("0.8.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double): LinearRegressionModel = { - new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(input) - } - - /** - * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * compute the true gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data - * matrix A as well as the corresponding right hand side label y - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LinearRegressionModel which has the weights and offset from training. - * - */ - @Since("0.8.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double): LinearRegressionModel = { - train(input, numIterations, stepSize, 1.0) - } - - /** - * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. We use the entire data set to - * compute the true gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data - * matrix A as well as the corresponding right hand side label y - * @param numIterations Number of iterations of gradient descent to run. - * @return a LinearRegressionModel which has the weights and offset from training. - * - */ - @Since("0.8.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int): LinearRegressionModel = { - train(input, numIterations, 1.0, 1.0) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 43c3154dd053b..1c3bdceab1d14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -24,8 +24,6 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel import org.apache.spark.mllib.util.{Loader, Saveable} -import org.apache.spark.rdd.RDD - /** * Regression model trained using RidgeRegression. 
@@ -100,113 +98,7 @@ class RidgeRegressionWithSGD private[mllib] ( .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - /** - * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 0.01, miniBatchFraction: 1.0}. - */ - @Since("0.8.0") - @deprecated("Use ml.regression.LinearRegression with elasticNetParam = 0.0. Note the default " + - "regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for LinearRegression.", "2.0.0") - def this() = this(1.0, 100, 0.01, 1.0) - override protected def createModel(weights: Vector, intercept: Double) = { new RidgeRegressionModel(weights, intercept) } } - -/** - * Top-level methods for calling RidgeRegression. - * - */ -@Since("0.8.0") -@deprecated("Use ml.regression.LinearRegression with elasticNetParam = 0.0. Note the default " + - "regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for LinearRegression.", "2.0.0") -object RidgeRegressionWithSGD { - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used - * in gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - * - */ - @Since("1.0.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Vector): RidgeRegressionModel = { - new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run( - input, initialWeights) - } - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * - */ - @Since("0.8.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double): RidgeRegressionModel = { - new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) - } - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * compute the true gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param regParam Regularization parameter. - * @param numIterations Number of iterations of gradient descent to run. 
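The deprecated LassoWithSGD, LinearRegressionWithSGD, and RidgeRegressionWithSGD training helpers removed in this patch all point callers to ml.regression.LinearRegression, with elasticNetParam = 1.0 for the Lasso case, 0.0 for the ridge case, and a regParam default of 0.0 rather than 0.01. A minimal migration sketch following those hints; the object name, toy data, and column names are illustrative only.

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

object WithSGDMigrationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("migration-sketch").getOrCreate()
    import spark.implicits._

    // Tiny stand-in training set; real code would load its own data.
    val training = Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.0)),
      (0.0, Vectors.dense(2.0, 1.3, 1.0)),
      (1.0, Vectors.dense(0.0, 1.2, -0.5))
    ).toDF("label", "features")

    // LassoWithSGD -> L1-regularized LinearRegression (elasticNetParam = 1.0).
    // The removed helper defaulted regParam to 0.01, so set it explicitly here.
    val lasso = new LinearRegression()
      .setElasticNetParam(1.0)
      .setRegParam(0.01)
      .setMaxIter(100)

    // RidgeRegressionWithSGD -> L2-regularized LinearRegression (elasticNetParam = 0.0).
    val ridge = new LinearRegression()
      .setElasticNetParam(0.0)
      .setRegParam(0.01)
      .setMaxIter(100)

    // LinearRegressionWithSGD -> plain LinearRegression with no regularization.
    val ols = new LinearRegression()
      .setRegParam(0.0)
      .setMaxIter(100)

    val lassoModel = lasso.fit(training)
    println(s"lasso coefficients: ${lassoModel.coefficients}")
    spark.stop()
  }
}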
- * @return a RidgeRegressionModel which has the weights and offset from training. - * - */ - @Since("0.8.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double): RidgeRegressionModel = { - train(input, numIterations, stepSize, regParam, 1.0) - } - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. We use the entire data set to - * compute the true gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @return a RidgeRegressionModel which has the weights and offset from training. - * - */ - @Since("0.8.0") - def train( - input: RDD[LabeledPoint], - numIterations: Int): RidgeRegressionModel = { - train(input, numIterations, 1.0, 0.01, 1.0) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala index 7f84be9f37822..b6eb10e9de00a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -65,7 +65,7 @@ object KMeansDataGenerator { } @Since("0.8.0") - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 6) { // scalastyle:off println println("Usage: KMeansGenerator " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 58fd010e4905f..c218681b3375e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -189,7 +189,7 @@ object LinearDataGenerator { } @Since("0.8.0") - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 2) { // scalastyle:off println println("Usage: LinearDataGenerator " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 68835bc79677f..7e9d9465441c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -65,7 +65,7 @@ object LogisticRegressionDataGenerator { } @Since("0.8.0") - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length != 5) { // scalastyle:off println println("Usage: LogisticRegressionGenerator " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 42c5bcdd39f76..7a308a5ec25c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -54,7 +54,7 @@ import org.apache.spark.rdd.RDD @Since("0.8.0") object MFDataGenerator { @Since("0.8.0") - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 2) { // scalastyle:off println println("Usage: MFDataGenerator " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 6d15a6bb01e4e..9198334ba02a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -173,7 +173,7 @@ object MLUtils extends Logging { * @see `org.apache.spark.mllib.util.MLUtils.loadLibSVMFile` */ @Since("1.0.0") - def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String) { + def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String): Unit = { // TODO: allow to specify label precision and feature precision. val dataStr = data.map { case LabeledPoint(label, features) => val sb = new StringBuilder(label.toString) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index c9468606544db..9f6ba025aedde 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD object SVMDataGenerator { @Since("0.8.0") - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { if (args.length < 2) { // scalastyle:off println println("Usage: SVMGenerator " + diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index b7956b6fd3e9a..69952f0b64ac2 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -20,7 +20,7 @@ import java.util.Arrays; import java.util.List; -import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; +import org.jtransforms.dct.DoubleDCT_1D; import org.junit.Assert; import org.junit.Test; diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java index c04e2e69541ba..208a5aaa2bb15 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -50,11 +50,8 @@ public void runLRUsingConstructor() { List validationData = LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); - LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); + LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(1.0, 100, 1.0, 1.0); lrImpl.setIntercept(true); - lrImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); LogisticRegressionModel model = lrImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -72,8 +69,8 @@ public void runLRUsingStaticMethods() { List validationData = LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); - LogisticRegressionModel model = LogisticRegressionWithSGD.train( - testRDD.rdd(), 100, 1.0, 1.0); + LogisticRegressionModel model = new LogisticRegressionWithSGD(1.0, 100, 0.01, 1.0) + .run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index 270e636f82117..a9a8b7f2b88d6 100644 --- 
a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -42,11 +42,11 @@ public void runKMeansUsingStaticMethods() { Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0); JavaRDD data = jsc.parallelize(points, 2); - KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL()); + KMeansModel model = KMeans.train(data.rdd(), 1, 1, KMeans.K_MEANS_PARALLEL()); assertEquals(1, model.clusterCenters().length); assertEquals(expectedCenter, model.clusterCenters()[0]); - model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM()); + model = KMeans.train(data.rdd(), 1, 1, KMeans.RANDOM()); assertEquals(expectedCenter, model.clusterCenters()[0]); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java index 1458cc72bc17f..35ad24bc2a84f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java @@ -51,10 +51,7 @@ public void runLassoUsingConstructor() { List validationData = LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - LassoWithSGD lassoSGDImpl = new LassoWithSGD(); - lassoSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.01) - .setNumIterations(20); + LassoWithSGD lassoSGDImpl = new LassoWithSGD(1.0, 20, 0.01, 1.0); LassoModel model = lassoSGDImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -72,7 +69,7 @@ public void runLassoUsingStaticMethods() { List validationData = LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); + LassoModel model = new LassoWithSGD(1.0, 100, 0.01, 1.0).run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java index 86c723aa00746..7e87588c4f0f6 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -33,7 +33,7 @@ private static int validatePrediction( List validationData, LinearRegressionModel model) { int numAccurate = 0; for (LabeledPoint point : validationData) { - Double prediction = model.predict(point.features()); + double prediction = model.predict(point.features()); // A prediction is off if the prediction is more than 0.5 away from expected value. 
if (Math.abs(prediction - point.label()) <= 0.5) { numAccurate++; @@ -53,7 +53,7 @@ public void runLinearRegressionUsingConstructor() { List validationData = LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); + LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0); linSGDImpl.setIntercept(true); LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); @@ -72,7 +72,8 @@ public void runLinearRegressionUsingStaticMethods() { List validationData = LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100); + LinearRegressionModel model = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0) + .run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); @@ -85,7 +86,7 @@ public void testPredictJavaRDD() { double[] weights = {10, 10}; JavaRDD testRDD = jsc.parallelize( LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); - LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); + LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0); LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); JavaRDD vectors = testRDD.map(LabeledPoint::features); JavaRDD predictions = model.predict(vectors); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java index 5a9389c424b44..63441950cd18f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -34,7 +34,7 @@ private static double predictionError(List validationData, RidgeRegressionModel model) { double errorSum = 0; for (LabeledPoint point : validationData) { - Double prediction = model.predict(point.features()); + double prediction = model.predict(point.features()); errorSum += (prediction - point.label()) * (prediction - point.label()); } return errorSum / validationData.size(); @@ -60,11 +60,7 @@ public void runRidgeRegressionUsingConstructor() { new ArrayList<>(data.subList(0, numExamples))); List validationData = data.subList(numExamples, 2 * numExamples); - RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); - ridgeSGDImpl.optimizer() - .setStepSize(1.0) - .setRegParam(0.0) - .setNumIterations(200); + RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(1.0, 200, 0.0, 1.0); RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd()); double unRegularizedErr = predictionError(validationData, model); @@ -85,10 +81,12 @@ public void runRidgeRegressionUsingStaticMethods() { new ArrayList<>(data.subList(0, numExamples))); List validationData = data.subList(numExamples, 2 * numExamples); - RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); + RidgeRegressionModel model = new RidgeRegressionWithSGD(1.0, 200, 0.0, 1.0) + .run(testRDD.rdd()); double unRegularizedErr = predictionError(validationData, model); - model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1); + model = new RidgeRegressionWithSGD(1.0, 200, 0.1, 1.0) + .run(testRDD.rdd()); double regularizedErr = predictionError(validationData, model); Assert.assertTrue(regularizedErr < 
unRegularizedErr); diff --git a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala index e2ee7c05ab399..f2343b7a88560 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala @@ -25,7 +25,7 @@ import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.BeforeAndAfterEach import org.scalatest.concurrent.Eventually -import org.scalatest.mockito.MockitoSugar.mock +import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamMap diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 1183cb0617610..e6025a5a53ca6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.when -import org.scalatest.mockito.MockitoSugar.mock +import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 9f2053dcc91fc..3ebf8a83a892c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -44,7 +44,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { private val seed = 42 - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() categoricalDataPointsRDD = sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()).map(_.asML) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 467f13f808a01..af3dd201d3b51 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -55,7 +55,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { private val eps: Double = 1e-5 private val absEps: Double = 1e-8 - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) .map(_.asML) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 0f0954e5d8cac..f03ed0b76eb80 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -42,7 +42,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _ - override def beforeAll() { + override def 
beforeAll(): Unit = { super.beforeAll() orderedLabeledPoints50_1000 = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) @@ -56,7 +56,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { // Tests calling train() ///////////////////////////////////////////////////////////////////////////// - def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier) { + def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier): Unit = { val categoricalFeatures = Map.empty[Int, Int] val numClasses = 2 val newRF = rf diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index c1a156959618e..f4f858c3e92dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -76,6 +76,10 @@ class RegressionEvaluatorSuite // mae evaluator.setMetricName("mae") assert(evaluator.evaluate(predictions) ~== 0.08399089 absTol 0.01) + + // var + evaluator.setMetricName("var") + assert(evaluator.evaluate(predictions) ~== 63.6944519 absTol 0.01) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 05d4a6ee2dabf..91bec50fb904f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -101,6 +101,20 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { } } + test("Binarizer should support sparse vector with negative threshold") { + val data = Seq( + (Vectors.sparse(3, Array(1), Array(0.5)), Vectors.dense(Array(1.0, 1.0, 1.0))), + (Vectors.dense(Array(0.0, 0.5, 0.0)), Vectors.dense(Array(1.0, 1.0, 1.0)))) + val df = data.toDF("feature", "expected") + val binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(-0.5) + binarizer.transform(df).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") + } + } test("read/write") { val t = new Binarizer() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 985e396000d05..079dabb3665be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D +import org.jtransforms.dct.DoubleDCT_1D import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index ae086d32d6d0b..6f6ab26cbac43 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.param.ParamsSuite import 
org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql._ @@ -423,33 +424,92 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { assert(readDiscretizer.hasDefault(readDiscretizer.outputCol)) } - test("Multiple Columns: Both inputCol and inputCols are set") { + test("Multiple Columns: Mismatched sizes of inputCols/outputCols") { val spark = this.spark import spark.implicits._ val discretizer = new QuantileDiscretizer() - .setInputCol("input") - .setOutputCol("result") + .setInputCols(Array("input")) + .setOutputCols(Array("result1", "result2")) .setNumBuckets(3) - .setInputCols(Array("input1", "input2")) val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) .map(Tuple1.apply).toDF("input") - // When both inputCol and inputCols are set, we throw Exception. intercept[IllegalArgumentException] { discretizer.fit(df) } } - test("Multiple Columns: Mismatched sizes of inputCols / outputCols") { + test("Multiple Columns: Mismatched sizes of inputCols/numBucketsArray") { val spark = this.spark import spark.implicits._ val discretizer = new QuantileDiscretizer() - .setInputCols(Array("input")) + .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) - .setNumBuckets(3) + .setNumBucketsArray(Array(2, 5, 10)) + val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0) + val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } + + test("Multiple Columns: Set both of numBuckets/numBucketsArray") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBucketsArray(Array(2, 5)) + .setNumBuckets(2) + val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0) + val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0) + val df = data1.zip(data2).toSeq.toDF("input1", "input2") + intercept[IllegalArgumentException] { + discretizer.fit(df) + } + } + + test("Setting numBucketsArray for Single-Column QuantileDiscretizer") { + val spark = this.spark + import spark.implicits._ + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBucketsArray(Array(2, 5)) val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) .map(Tuple1.apply).toDF("input") intercept[IllegalArgumentException] { discretizer.fit(df) } } + + test("Assert exception is thrown if both multi-column and single-column params are set") { + val spark = this.spark + import spark.implicits._ + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("inputCol", "feature1"), + ("inputCols", Array("feature1", "feature2"))) + ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("outputCols", Array("result1", "result2"))) + // this should fail because at least one of inputCol and inputCols must be set + ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("outputCol", "feature1")) + } + + test("Setting inputCol without setting outputCol") { + val spark = this.spark + import spark.implicits._ + + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + .map(Tuple1.apply).toDF("input") + val numBuckets = 2 + val discretizer = new QuantileDiscretizer() + 
.setInputCol("input") + .setNumBuckets(numBuckets) + val model = discretizer.fit(df) + val result = model.transform(df) + + val observedNumBuckets = result.select(discretizer.getOutputCol).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index add1cc17ea057..efd56f7073a19 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -25,7 +25,7 @@ class RFormulaParserSuite extends SparkFunSuite { formula: String, label: String, terms: Seq[String], - schema: StructType = new StructType) { + schema: StructType = new StructType): Unit = { val resolved = RFormulaParser.parse(formula).resolve(schema) assert(resolved.label == label) val simpleTerms = terms.map { t => diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 630e785e59507..49ebcb385640e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -40,7 +40,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { private val seed = 42 - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() categoricalDataPointsRDD = sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 884fe2d11bf5a..60007975c3b52 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -47,7 +47,7 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { private var trainData: RDD[LabeledPoint] = _ private var validationData: RDD[LabeledPoint] = _ - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) .map(_.asML) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index c6dabd1b28829..0243e8d2335ee 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -38,7 +38,7 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() orderedLabeledPoints50_1000 = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) @@ -49,7 +49,7 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ // Tests calling train() ///////////////////////////////////////////////////////////////////////////// - def regressionTestWithContinuousFeatures(rf: RandomForestRegressor) { + def 
regressionTestWithContinuousFeatures(rf: RandomForestRegressor): Unit = { val categoricalFeaturesInfo = Map.empty[Int, Int] val newRF = rf .setImpurity("variance") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index a63ab913f2c22..ae5e979983b4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -485,7 +485,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } } - def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) { + def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures( + strategy: OldStrategy): Unit = { val numFeatures = 50 val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) val rdd = sc.parallelize(arr).map(_.asML.toInstance) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 8a0a48ff6095b..90079c9848823 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -56,7 +56,7 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => sc.setCheckpointDir(checkpointDir) } - override def afterAll() { + override def afterAll(): Unit = { try { Utils.deleteRecursively(new File(checkpointDir)) } finally { @@ -127,7 +127,7 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => dataframe: DataFrame, transformer: Transformer, expectedMessagePart : String, - firstResultCol: String) { + firstResultCol: String): Unit = { withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") { val exceptionOnDf = intercept[Throwable] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 5cf4377768516..d4e9da3c6263e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -206,7 +206,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], - expectedAcc: Double = 0.83) { + expectedAcc: Double = 0.83): Unit = { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label } @@ -224,12 +224,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer - .setStepSize(10.0) - .setRegParam(0.0) - .setNumIterations(20) - .setConvergenceTol(0.0005) + val lr = new LogisticRegressionWithSGD(10.0, 20, 0.0, 1.0).setIntercept(true) + lr.optimizer.setConvergenceTol(0.0005) val model = lr.run(testRDD) @@ -300,11 +296,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w testRDD.cache() // Use half as many iterations as the previous test. 
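In these test updates the pattern of a no-arg constructor followed by a chain of optimizer setters is collapsed into the four-argument constructors. Going by the removed default constructor `def this() = this(1.0, 100, 0.01, 1.0)` and its documented defaults {stepSize: 1.0, numIterations: 100, regParam: 0.01, miniBatchFraction: 1.0}, the positional order appears to be (stepSize, numIterations, regParam, miniBatchFraction). A small sketch spelling that out; note these constructors are private[mllib], so only test and internal code under the mllib packages can call them.

// Assumed argument order, inferred from the removed `def this() = this(1.0, 100, 0.01, 1.0)`
// and its documented defaults {stepSize, numIterations, regParam, miniBatchFraction}.
val lr = new LogisticRegressionWithSGD(
    10.0, // stepSize
    20,   // numIterations
    0.0,  // regParam
    1.0   // miniBatchFraction
  ).setIntercept(true)
lr.optimizer.setConvergenceTol(0.0005) // remaining optimizer knobs are still set via `optimizer`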
- val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer - .setStepSize(10.0) - .setRegParam(0.0) - .setNumIterations(10) + val lr = new LogisticRegressionWithSGD(10.0, 10, 0.0, 1.0).setIntercept(true) val model = lr.run(testRDD, initialWeights) @@ -335,11 +327,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w testRDD.cache() // Use half as many iterations as the previous test. - val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer. - setStepSize(1.0). - setNumIterations(10). - setRegParam(1.0) + val lr = new LogisticRegressionWithSGD(1.0, 10, 1.0, 1.0).setIntercept(true) val model = lr.run(testRDD, initialWeights) @@ -916,7 +904,7 @@ class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSpar }.cache() // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. - val model = LogisticRegressionWithSGD.train(points, 2) + val model = new LogisticRegressionWithSGD(1.0, 2, 0.0, 1.0).run(points) val predictions = model.predict(points.map(_.features)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 725389813b3e2..47dac3ec29a5c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -91,7 +91,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { import NaiveBayes.{Multinomial, Bernoulli} - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = { val numOfPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 3676d9c5debc8..007b8ae6e1a6a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -62,7 +62,7 @@ object SVMSuite { class SVMSuite extends SparkFunSuite with MLlibTestSparkContext { - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 5f797a60f09e6..7349e0319324a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -23,23 +23,17 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase} import 
org.apache.spark.streaming.dstream.DStream -class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { +class StreamingLogisticRegressionSuite + extends SparkFunSuite + with LocalStreamingContext + with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 - var ssc: StreamingContext = _ - - override def afterFunction() { - super.afterFunction() - if (ssc != null) { - ssc.stop() - } - } - // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index c4bf5b27187f6..149a525a58ff6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -367,7 +367,7 @@ class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext { for (initMode <- Seq(KMeans.RANDOM, KMeans.K_MEANS_PARALLEL)) { // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. - val model = KMeans.train(points, 2, 2, 1, initMode) + val model = KMeans.train(points, 2, 2, initMode) val predictions = model.predict(points).collect() val cost = model.computeCost(points) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index a1ac10c06c697..415ac87275390 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -20,23 +20,14 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom -class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { +class StreamingKMeansSuite extends SparkFunSuite with LocalStreamingContext with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 - var ssc: StreamingContext = _ - - override def afterFunction() { - super.afterFunction() - if (ssc != null) { - ssc.stop() - } - } - test("accuracy for single center and equivalence to grand average") { // set parameters val numBatches = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index b4520d42fedf5..184c89c9eaaf9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import breeze.linalg.{CSCMatrix, Matrix => BM} import org.mockito.Mockito.when -import org.scalatest.mockito.MockitoSugar._ +import org.scalatestplus.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config.Kryo._ @@ -39,7 +39,7 @@ class MatricesSuite extends SparkFunSuite { val ser = new KryoSerializer(conf).newInstance() - def check[T: 
ClassTag](t: T) { + def check[T: ClassTag](t: T): Unit = { assert(ser.deserialize[T](ser.serialize(t)) === t) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index fee0b02bf8ed8..c0c5c5c7d98d5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -42,7 +42,7 @@ class VectorsSuite extends SparkFunSuite with Logging { conf.set(KRYO_REGISTRATION_REQUIRED, true) val ser = new KryoSerializer(conf).newInstance() - def check[T: ClassTag](t: T) { + def check[T: ClassTag](t: T): Unit = { assert(ser.deserialize[T](ser.serialize(t)) === t) } @@ -510,4 +510,27 @@ class VectorsSuite extends SparkFunSuite with Logging { Vectors.sparse(-1, Array((1, 2.0))) } } + + test("dot product only supports vectors of same size") { + val vSize4 = Vectors.dense(arr) + val vSize1 = Vectors.zeros(1) + intercept[IllegalArgumentException]{ vSize1.dot(vSize4) } + } + + test("dense vector dot product") { + val dv = Vectors.dense(arr) + assert(dv.dot(dv) === 0.26) + } + + test("sparse vector dot product") { + val sv = Vectors.sparse(n, indices, values) + assert(sv.dot(sv) === 0.26) + } + + test("mixed sparse and dense vector dot product") { + val sv = Vectors.sparse(n, indices, values) + val dv = Vectors.dense(arr) + assert(sv.dot(dv) === 0.26) + assert(dv.dot(sv) === 0.26) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index f6a996940291c..9d7177e0a149e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -35,7 +35,7 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val numPartitions = 3 var gridBasedMat: BlockMatrix = _ - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() val blocks: Seq[((Int, Int), Matrix)] = Seq( ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 37d75103d18d2..d197f06a393e8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -29,7 +29,7 @@ class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val n = 4 var mat: CoordinateMatrix = _ - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() val entries = sc.parallelize(Seq( (0, 0, 1.0), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index cca4eb4e4260e..e961d10711860 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -36,7 +36,7 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { ).map(x => IndexedRow(x._1, x._2)) var indexedRows: RDD[IndexedRow] = _ - override def 
beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() indexedRows = sc.parallelize(data, 2) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index a0c4c68243e67..0a4b11935580a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -57,7 +57,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { var denseMat: RowMatrix = _ var sparseMat: RowMatrix = _ - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() denseMat = new RowMatrix(sc.parallelize(denseData, 2)) sparseMat = new RowMatrix(sc.parallelize(sparseData, 2)) @@ -213,7 +213,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { brzNorm(v, 1.0) < 1e-6 } - def assertColumnEqualUpToSign(A: BDM[Double], B: BDM[Double], k: Int) { + def assertColumnEqualUpToSign(A: BDM[Double], B: BDM[Double], k: Int): Unit = { assert(A.rows === B.rows) for (j <- 0 until k) { val aj = A(::, j) @@ -338,7 +338,7 @@ class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext var mat: RowMatrix = _ - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() val m = 4 val n = 200000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index b3bf5a2a8f2cc..a629c6951abcd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.util.StatCounter class RandomDataGeneratorSuite extends SparkFunSuite { - def apiChecks(gen: RandomDataGenerator[Double]) { + def apiChecks(gen: RandomDataGenerator[Double]): Unit = { // resetting seed should generate the same sequence of random numbers gen.setSeed(42L) val array1 = (0 until 1000).map(_ => gen.nextValue()) @@ -56,7 +56,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite { def distributionChecks(gen: RandomDataGenerator[Double], mean: Double = 0.0, stddev: Double = 1.0, - epsilon: Double = 0.01) { + epsilon: Double = 0.01): Unit = { for (seed <- 0 until 5) { gen.setSeed(seed.toLong) val sample = (0 until 100000).map { _ => gen.nextValue()} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 9b4dc29d326a1..470e1016dab39 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -38,7 +38,7 @@ class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Seri expectedNumPartitions: Int, expectedMean: Double, expectedStddev: Double, - epsilon: Double = 0.01) { + epsilon: Double = 0.01): Unit = { val stats = rdd.stats() assert(expectedSize === stats.count) assert(expectedNumPartitions === rdd.partitions.size) @@ -53,7 +53,7 @@ class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Seri expectedNumPartitions: Int, expectedMean: Double, expectedStddev: Double, - epsilon: Double = 0.01) { + epsilon: Double = 0.01): Unit = { assert(expectedNumPartitions === rdd.partitions.size) val values = 
new ArrayBuffer[Double]() rdd.collect.foreach { vector => { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index b08ad99f4f204..9be87db873dad 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -224,7 +224,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { negativeWeights: Boolean = false, numUserBlocks: Int = -1, numProductBlocks: Int = -1, - negativeFactors: Boolean = true) { + negativeFactors: Boolean = true): Unit = { // scalastyle:on val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index d96103d01e4ab..f336dac0ccb5d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -33,7 +33,7 @@ private object LassoSuite { class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => // A prediction is off if the prediction is more than 0.5 away from expected value. math.abs(prediction - expected.label) > 0.5 @@ -55,8 +55,7 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { } val testRDD = sc.parallelize(testData, 2).cache() - val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40) + val ls = new LassoWithSGD(1.0, 40, 0.01, 1.0) val model = ls.run(testRDD) val weight0 = model.weights(0) @@ -99,8 +98,8 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { val testRDD = sc.parallelize(testData, 2).cache() - val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40).setConvergenceTol(0.0005) + val ls = new LassoWithSGD(1.0, 40, 0.01, 1.0) + ls.optimizer.setConvergenceTol(0.0005) val model = ls.run(testRDD, initialWeights) val weight0 = model.weights(0) @@ -153,7 +152,7 @@ class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext { }.cache() // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. 
- val model = LassoWithSGD.train(points, 2) + val model = new LassoWithSGD(1.0, 2, 0.01, 1.0).run(points) val predictions = model.predict(points.map(_.features)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 0694079b9df9e..be0834d0fd7df 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -33,7 +33,7 @@ private object LinearRegressionSuite { class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => // A prediction is off if the prediction is more than 0.5 away from expected value. math.abs(prediction - expected.label) > 0.5 @@ -46,7 +46,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("linear regression") { val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( 3.0, Array(10.0, 10.0), 100, 42), 2).cache() - val linReg = new LinearRegressionWithSGD().setIntercept(true) + val linReg = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0).setIntercept(true) linReg.optimizer.setNumIterations(1000).setStepSize(1.0) val model = linReg.run(testRDD) @@ -72,7 +72,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("linear regression without intercept") { val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( 0.0, Array(10.0, 10.0), 100, 42), 2).cache() - val linReg = new LinearRegressionWithSGD().setIntercept(false) + val linReg = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0).setIntercept(false) linReg.optimizer.setNumIterations(1000).setStepSize(1.0) val model = linReg.run(testRDD) @@ -103,7 +103,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1)))) LabeledPoint(label, sv) }.cache() - val linReg = new LinearRegressionWithSGD().setIntercept(false) + val linReg = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0).setIntercept(false) linReg.optimizer.setNumIterations(1000).setStepSize(1.0) val model = linReg.run(sparseRDD) @@ -160,7 +160,7 @@ class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkC }.cache() // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. - val model = LinearRegressionWithSGD.train(points, 2) + val model = new LinearRegressionWithSGD(1.0, 2, 0.0, 1.0).run(points) val predictions = model.predict(points.map(_.features)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 815be32d2e510..2d6aec184ad9d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -60,18 +60,13 @@ class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val validationRDD = sc.parallelize(validationData, 2).cache() // First run without regularization. 
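The regression suites also stop going through the parameterless *WithSGD constructors and the train(...) companion helpers, and pass the optimizer settings positionally instead. Judging from the setter chains they replace, the four arguments are step size, number of iterations, regularization parameter and mini-batch fraction. The sketch below restates that pattern with illustrative data; note the constructors are package-private to org.apache.spark.mllib, so this only compiles from inside that package tree, as these suites are:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LassoWithSGD}

object LassoConstructorSketch {
  // Sketch only: assumes a SparkContext provided by the test harness.
  def fitLasso(sc: SparkContext): Unit = {
    val points = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0))), numSlices = 2).cache()

    // stepSize = 1.0, numIterations = 40, regParam = 0.01, miniBatchFraction = 1.0
    val lasso = new LassoWithSGD(1.0, 40, 0.01, 1.0)
    // Anything beyond the four constructor arguments still goes through the optimizer.
    lasso.optimizer.setConvergenceTol(0.0005)
    val model = lasso.run(points)
    println(model.weights)
  }
}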
- val linearReg = new LinearRegressionWithSGD() - linearReg.optimizer.setNumIterations(200) - .setStepSize(1.0) + val linearReg = new LinearRegressionWithSGD(1.0, 200, 0.0, 1.0) val linearModel = linearReg.run(testRDD) val linearErr = predictionError( linearModel.predict(validationRDD.map(_.features)).collect(), validationData) - val ridgeReg = new RidgeRegressionWithSGD() - ridgeReg.optimizer.setNumIterations(200) - .setRegParam(0.1) - .setStepSize(1.0) + val ridgeReg = new RidgeRegressionWithSGD(1.0, 200, 0.1, 1.0) val ridgeModel = ridgeReg.run(testRDD) val ridgeErr = predictionError( ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) @@ -110,7 +105,7 @@ class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkCo }.cache() // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. - val model = RidgeRegressionWithSGD.train(points, 2) + val model = new RidgeRegressionWithSGD(1.0, 2, 0.01, 1.0).run(points) val predictions = model.predict(points.map(_.features)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index eaeaa3fc1e68d..8e2d7d10f2ce2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -22,31 +22,25 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream -class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { +class StreamingLinearRegressionSuite + extends SparkFunSuite + with LocalStreamingContext + with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 - var ssc: StreamingContext = _ - - override def afterFunction() { - super.afterFunction() - if (ssc != null) { - ssc.stop() - } - } - // Assert that two values are equal within tolerance epsilon - def assertEqual(v1: Double, v2: Double, epsilon: Double) { + def assertEqual(v1: Double, v2: Double, epsilon: Double): Unit = { def errorMessage = v1.toString + " did not equal " + v2.toString assert(math.abs(v1-v2) <= epsilon, errorMessage) } // Assert that model predictions are correct - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => // A prediction is off if the prediction is more than 0.5 away from expected value. 
math.abs(prediction - expected.label) > 0.5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 88b9d4c039ba9..b738236473230 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -437,7 +437,7 @@ object DecisionTreeSuite extends SparkFunSuite { def validateClassifier( model: DecisionTreeModel, input: Seq[LabeledPoint], - requiredAccuracy: Double) { + requiredAccuracy: Double): Unit = { val predictions = input.map(x => model.predict(x.features)) val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label @@ -450,7 +450,7 @@ object DecisionTreeSuite extends SparkFunSuite { def validateRegressor( model: DecisionTreeModel, input: Seq[LabeledPoint], - requiredMSE: Double) { + requiredMSE: Double): Unit = { val predictions = input.map(x => model.predict(x.features)) val squaredError = predictions.zip(input).map { case (prediction, expected) => val err = prediction - expected.label diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index d43e62bb65535..e04d7b7c327a8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -37,7 +37,7 @@ object EnsembleTestHelper { numCols: Int, expectedMean: Double, expectedStddev: Double, - epsilon: Double) { + epsilon: Double): Unit = { val values = new mutable.ArrayBuffer[Double]() data.foreach { row => assert(row.size == numCols) @@ -51,7 +51,7 @@ object EnsembleTestHelper { def validateClassifier( model: TreeEnsembleModel, input: Seq[LabeledPoint], - requiredAccuracy: Double) { + requiredAccuracy: Double): Unit = { val predictions = input.map(x => model.predict(x.features)) val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label @@ -68,7 +68,7 @@ object EnsembleTestHelper { model: TreeEnsembleModel, input: Seq[LabeledPoint], required: Double, - metricName: String = "mse") { + metricName: String = "mse"): Unit = { val predictions = input.map(x => model.predict(x.features)) val errors = predictions.zip(input).map { case (prediction, point) => point.label - prediction diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index bec61ba6a003c..b1a385a576cea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils * Test suite for [[RandomForest]]. 
*/ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { - def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { + def binaryClassificationTestWithContinuousFeatures(strategy: Strategy): Unit = { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val numTrees = 1 @@ -68,7 +68,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { binaryClassificationTestWithContinuousFeatures(strategy) } - def regressionTestWithContinuousFeatures(strategy: Strategy) { + def regressionTestWithContinuousFeatures(strategy: Strategy): Unit = { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) val numTrees = 1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala index 2853b752cb85c..79d4785fd6fa7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala @@ -25,7 +25,7 @@ import org.apache.spark.internal.config.Network.RPC_MESSAGE_MAX_SIZE trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() val conf = new SparkConf() .setMaster("local-cluster[2, 1, 1024]") @@ -34,7 +34,7 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => sc = new SparkContext(conf) } - override def afterAll() { + override def afterAll(): Unit = { try { if (sc != null) { sc.stop() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 720237bd2dddd..f9a3cd088314e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -31,7 +31,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => @transient var sc: SparkContext = _ @transient var checkpointDir: String = _ - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() spark = SparkSession.builder .master("local[2]") @@ -43,7 +43,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => sc.setCheckpointDir(checkpointDir) } - override def afterAll() { + override def afterAll(): Unit = { try { Utils.deleteRecursively(new File(checkpointDir)) SparkSession.clearActiveSession() diff --git a/pom.xml b/pom.xml index 6c474f5f7a3e7..9c2aa9de85ce6 100644 --- a/pom.xml +++ b/pom.xml @@ -125,7 +125,7 @@ 2.7.4 2.5.0 ${hadoop.version} - 3.4.6 + 3.4.14 2.7.1 0.4.2 org.spark-project.hive @@ -139,7 +139,7 @@ 2.3.0 10.12.1.1 1.10.1 - 1.5.5 + 1.5.6 nohive com.twitter 1.6.0 @@ -164,14 +164,14 @@ 3.4.1 3.2.2 - 2.12.8 + 2.12.10 2.12 --diff --test true 1.9.13 - 2.9.9 - 2.9.9.3 + 2.9.10 + 2.9.10 1.1.7.3 1.1.2 1.10 @@ -240,7 +240,7 @@ --> ${session.executionRootDirectory} - 512m + 1g @@ -620,7 +620,7 @@ com.github.luben zstd-jni - 1.4.2-1 + 1.4.3-1 com.clearspring.analytics @@ -786,14 +786,8 @@ org.scalanlp breeze_${scala.binary.version} - 0.13.2 + 1.0 - - - junit - junit - org.apache.commons commons-math3 @@ -839,7 +833,7 @@ org.scala-lang.modules scala-parser-combinators_${scala.binary.version} - 1.1.0 + 1.1.2 jline @@ -849,7 +843,7 @@ 
org.scalatest scalatest_${scala.binary.version} - 3.0.5 + 3.0.8 test @@ -867,7 +861,7 @@ org.scalacheck scalacheck_${scala.binary.version} - 1.13.5 + 1.14.2 test @@ -1343,6 +1337,10 @@ io.netty netty + + com.github.spotbugs + spotbugs-annotations + @@ -2002,75 +2000,6 @@ - - ${hive.group} - hive-contrib - ${hive.version} - test - - - ${hive.group} - hive-exec - - - ${hive.group} - hive-serde - - - ${hive.group} - hive-shims - - - commons-codec - commons-codec - - - org.slf4j - slf4j-api - - - - - ${hive.group}.hcatalog - hive-hcatalog-core - ${hive.version} - test - - - ${hive.group} - hive-exec - - - ${hive.group} - hive-metastore - - - ${hive.group} - hive-cli - - - ${hive.group} - hive-common - - - com.google.guava - guava - - - org.slf4j - slf4j-api - - - org.codehaus.jackson - jackson-mapper-asl - - - org.apache.hadoop - * - - - - org.apache.orc orc-core @@ -2287,6 +2216,17 @@ + + enforce-no-duplicate-dependencies + + enforce + + + + + + + @@ -2974,7 +2914,6 @@ 3.2.0 2.13.0 - 3.4.13 org.apache.hive core ${hive23.version} @@ -3053,6 +2992,19 @@ scala-2.12 + + + scala-2.13 + + + + org.scala-lang.modules + scala-parallel-collections_${scala.binary.version} + 0.2.0 + + + + - - ${hive.group} - hive-contrib - - - ${hive.group}.hcatalog - hive-hcatalog-core - org.eclipse.jetty jetty-server diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 36d4ac095e10c..9517a599be633 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -72,7 +72,7 @@ object HiveThriftServer2 extends Logging { server } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { // If the arguments contains "-h" or "--help", print out the usage and exit. if (args.contains("-h") || args.contains("--help")) { HiveServer2.main(args) @@ -303,7 +303,7 @@ private[hive] class HiveThriftServer2(sqlContext: SQLContext) // started, and then once only. 
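Immediately below, the thrift server guards its startup with an AtomicBoolean field named started ("started, and then once only"). The snippet that follows is a generic sketch of that start-at-most-once idiom, not the server's actual logic:

import java.util.concurrent.atomic.AtomicBoolean

// Generic start-at-most-once guard: only the first caller flips
// false -> true and runs the start body; later calls are no-ops.
class StartOnce(body: () => Unit) {
  private val started = new AtomicBoolean(false)

  def start(): Unit = {
    if (started.compareAndSet(false, true)) {
      body()
    }
  }
}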
private val started = new AtomicBoolean(false) - override def init(hiveConf: HiveConf) { + override def init(hiveConf: HiveConf): Unit = { val sparkSqlCliService = new SparkSQLCLIService(this, sqlContext) setSuperField(this, "cliService", sparkSqlCliService) addService(sparkSqlCliService) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala index 599294dfbb7d7..a4024be67ac9c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.hive.thriftserver private[hive] object ReflectionUtils { - def setSuperField(obj : Object, fieldName: String, fieldValue: Object) { + def setSuperField(obj : Object, fieldName: String, fieldValue: Object): Unit = { setAncestorField(obj, 1, fieldName, fieldValue) } - def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) { + def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef): Unit = { val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next() val field = ancestor.getDeclaredField(fieldName) field.setAccessible(true) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 69e85484ccf8e..9ca6c39d016ba 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.HiveResult import org.apache.spark.sql.execution.command.SetCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.{Utils => SparkUtils} private[hive] class SparkExecuteStatementOperation( @@ -77,7 +78,7 @@ private[hive] class SparkExecuteStatementOperation( HiveThriftServer2.listener.onOperationClosed(statementId) } - def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int): Unit = { dataTypes(ordinal) match { case StringType => to += from.getString(ordinal) @@ -103,6 +104,8 @@ private[hive] class SparkExecuteStatementOperation( to += from.getAs[Timestamp](ordinal) case BinaryType => to += from.getAs[Array[Byte]](ordinal) + case CalendarIntervalType => + to += HiveResult.toHiveString((from.getAs[CalendarInterval](ordinal), CalendarIntervalType)) case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] => val hiveString = HiveResult.toHiveString((from.get(ordinal), dataTypes(ordinal))) to += hiveString @@ -264,6 +267,13 @@ private[hive] class SparkExecuteStatementOperation( // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => + // When cancel() or close() is called very quickly after the query is started, + // then they may both call cleanup() before Spark Jobs are started. 
But before background + // task interrupted, it may have start some spark job, so we need to cancel again to + // make sure job was cancelled when background thread was interrupted + if (statementId != null) { + sqlContext.sparkContext.cancelJobGroup(statementId) + } val currentState = getStatus().getState() if (currentState.isTerminal) { // This may happen if the execution was cancelled, and then closed from another thread. @@ -300,7 +310,7 @@ private[hive] class SparkExecuteStatementOperation( } } - private def cleanup(state: OperationState) { + private def cleanup(state: OperationState): Unit = { setState(state) if (runInBackground) { val backgroundHandle = getBackgroundHandle() @@ -331,7 +341,11 @@ private[hive] class SparkExecuteStatementOperation( object SparkExecuteStatementOperation { def getTableSchema(structType: StructType): TableSchema = { val schema = structType.map { field => - val attrTypeString = if (field.dataType == NullType) "void" else field.dataType.catalogString + val attrTypeString = field.dataType match { + case NullType => "void" + case CalendarIntervalType => StringType.catalogString + case other => other.catalogString + } new FieldSchema(field.name, attrTypeString, field.getComment.getOrElse("")) } new TableSchema(schema.asJava) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala new file mode 100644 index 0000000000000..7a6a8c59b7216 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.UUID + +import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType +import org.apache.hive.service.cli.{HiveSQLException, OperationState} +import org.apache.hive.service.cli.operation.GetTypeInfoOperation +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.{Utils => SparkUtils} + +/** + * Spark's own GetTypeInfoOperation + * + * @param sqlContext SQLContext to use + * @param parentSession a HiveSession from SessionManager + */ +private[hive] class SparkGetTypeInfoOperation( + sqlContext: SQLContext, + parentSession: HiveSession) + extends GetTypeInfoOperation(parentSession) with Logging { + + private var statementId: String = _ + + override def close(): Unit = { + super.close() + HiveThriftServer2.listener.onOperationClosed(statementId) + } + + override def runInternal(): Unit = { + statementId = UUID.randomUUID().toString + val logMsg = "Listing type info" + logInfo(s"$logMsg with $statementId") + setState(OperationState.RUNNING) + // Always use the latest class loader provided by executionHive's state. + val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader + Thread.currentThread().setContextClassLoader(executionHiveClassLoader) + + if (isAuthV2Enabled) { + authorizeMetaGets(HiveOperationType.GET_TYPEINFO, null) + } + + HiveThriftServer2.listener.onStatementStart( + statementId, + parentSession.getSessionHandle.getSessionId.toString, + logMsg, + statementId, + parentSession.getUsername) + + try { + ThriftserverShimUtils.supportedType().foreach(typeInfo => { + val rowData = Array[AnyRef]( + typeInfo.getName, // TYPE_NAME + typeInfo.toJavaSQLType.asInstanceOf[AnyRef], // DATA_TYPE + typeInfo.getMaxPrecision.asInstanceOf[AnyRef], // PRECISION + typeInfo.getLiteralPrefix, // LITERAL_PREFIX + typeInfo.getLiteralSuffix, // LITERAL_SUFFIX + typeInfo.getCreateParams, // CREATE_PARAMS + typeInfo.getNullable.asInstanceOf[AnyRef], // NULLABLE + typeInfo.isCaseSensitive.asInstanceOf[AnyRef], // CASE_SENSITIVE + typeInfo.getSearchable.asInstanceOf[AnyRef], // SEARCHABLE + typeInfo.isUnsignedAttribute.asInstanceOf[AnyRef], // UNSIGNED_ATTRIBUTE + typeInfo.isFixedPrecScale.asInstanceOf[AnyRef], // FIXED_PREC_SCALE + typeInfo.isAutoIncrement.asInstanceOf[AnyRef], // AUTO_INCREMENT + typeInfo.getLocalizedName, // LOCAL_TYPE_NAME + typeInfo.getMinimumScale.asInstanceOf[AnyRef], // MINIMUM_SCALE + typeInfo.getMaximumScale.asInstanceOf[AnyRef], // MAXIMUM_SCALE + null, // SQL_DATA_TYPE, unused + null, // SQL_DATETIME_SUB, unused + typeInfo.getNumPrecRadix // NUM_PREC_RADIX + ) + rowSet.addRow(rowData) + }) + setState(OperationState.FINISHED) + } catch { + case e: HiveSQLException => + setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + statementId, e.getMessage, SparkUtils.exceptionString(e)) + throw e + } + HiveThriftServer2.listener.onStatementFinish(statementId) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index b9614d49eadbd..e3efa2d3ae8c9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -63,7 +63,7 @@ 
private[hive] object SparkSQLCLIDriver extends Logging { * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while * a command is being processed by the current thread. */ - def installSignalHandler() { + def installSignalHandler(): Unit = { HiveInterruptUtils.add(() => { // Handle remote execution mode if (SparkSQLEnv.sparkContext != null) { @@ -77,7 +77,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { }) } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val oproc = new OptionsProcessor() if (!oproc.process_stage1(args)) { System.exit(1) @@ -111,6 +111,11 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Set all properties specified via command line. val conf: HiveConf = sessionState.getConf + // Hive 2.0.0 onwards HiveConf.getClassLoader returns the UDFClassLoader (created by Hive). + // Because of this spark cannot find the jars as class loader got changed + // Hive changed the class loader because of HIVE-11878, so it is required to use old + // classLoader as sparks loaded all the jars in this classLoader + conf.setClassLoader(Thread.currentThread().getContextClassLoader) sessionState.cmdProperties.entrySet().asScala.foreach { item => val key = item.getKey.toString val value = item.getValue.toString @@ -133,20 +138,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Clean up after we exit ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() } - val remoteMode = isRemoteMode(sessionState) - // "-h" option has been passed, so connect to Hive thrift server. - if (!remoteMode) { - // Hadoop-20 and above - we need to augment classpath using hiveconf - // components. - // See also: code in ExecDriver.java - var loader = conf.getClassLoader - val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS) - if (StringUtils.isNotBlank(auxJars)) { - loader = ThriftserverShimUtils.addToClassPath(loader, StringUtils.split(auxJars, ",")) - } - conf.setClassLoader(loader) - Thread.currentThread().setContextClassLoader(loader) - } else { + if (isRemoteMode(sessionState)) { // Hive 1.2 + not supported in CLI throw new RuntimeException("Remote operations not supported") } @@ -164,6 +156,22 @@ private[hive] object SparkSQLCLIDriver extends Logging { val cli = new SparkSQLCLIDriver cli.setHiveVariables(oproc.getHiveVariables) + // In SparkSQL CLI, we may want to use jars augmented by hiveconf + // hive.aux.jars.path, here we add jars augmented by hiveconf to + // Spark's SessionResourceLoader to obtain these jars. + val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS) + if (StringUtils.isNotBlank(auxJars)) { + val resourceLoader = SparkSQLEnv.sqlContext.sessionState.resourceLoader + StringUtils.split(auxJars, ",").foreach(resourceLoader.addJar(_)) + } + + // The class loader of CliSessionState's conf is current main thread's class loader + // used to load jars passed by --jars. One class loader used by AddJarCommand is + // sharedState.jarClassLoader which contain jar path passed by --jars in main thread. + // We set CliSessionState's conf class loader to sharedState.jarClassLoader. + // Thus we can load all jars passed by --jars and AddJarCommand. + sessionState.getConf.setClassLoader(SparkSQLEnv.sqlContext.sharedState.jarClassLoader) + // TODO work around for set the log output to console, because the HiveContext // will set the output into an invalid buffer. 
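To summarise the CLI changes above: instead of splicing hive.aux.jars.path onto a Hive class loader, the driver now registers those jars with Spark's own session resource loader and relies on the shared jar class loader, so jars from --jars, ADD JAR and hive.aux.jars.path all resolve the same way. A rough sketch of that hand-off is below; in the patch itself the shared loader is installed on the CliSessionState configuration rather than the thread, and these accessors are package-private to org.apache.spark.sql, so this is only callable from code inside that package tree, as the CLI driver is:

import org.apache.spark.sql.SQLContext

object AuxJarSketch {
  def registerAuxJars(sqlContext: SQLContext, auxJars: String): Unit = {
    // hive.aux.jars.path is a comma-separated list of jar paths; hand each
    // one to the session resource loader so Spark can serve it to executors.
    auxJars.split(",").map(_.trim).filter(_.nonEmpty).foreach { jar =>
      sqlContext.sessionState.resourceLoader.addJar(jar)
    }
    // Use the loader that already holds --jars and ADD JAR entries, so
    // Hive classes loaded afterwards can see the same jars.
    Thread.currentThread()
      .setContextClassLoader(sqlContext.sharedState.jarClassLoader)
  }
}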
sessionState.in = System.in diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index c32d908ad1bba..1644ecb2453be 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -43,7 +43,7 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC extends CLIService(hiveServer) with ReflectedCompositeService { - override def init(hiveConf: HiveConf) { + override def init(hiveConf: HiveConf): Unit = { setSuperField(this, "hiveConf", hiveConf) val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, sqlContext) @@ -105,7 +105,7 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC } private[thriftserver] trait ReflectedCompositeService { this: AbstractService => - def initCompositeService(hiveConf: HiveConf) { + def initCompositeService(hiveConf: HiveConf): Unit = { // Emulating `CompositeService.init(hiveConf)` val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") serviceList.asScala.foreach(_.init(hiveConf)) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 960fdd11db15d..362ac362e9718 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -94,7 +94,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont override def getSchema: Schema = tableSchema - override def destroy() { + override def destroy(): Unit = { super.destroy() hiveResponse = null tableSchema = null diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 674da18ca1803..2fda9d0a4f60f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -33,7 +33,7 @@ private[hive] object SparkSQLEnv extends Logging { var sqlContext: SQLContext = _ var sparkContext: SparkContext = _ - def init() { + def init(): Unit = { if (sqlContext == null) { val sparkConf = new SparkConf(loadDefaults = true) // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of @@ -60,7 +60,7 @@ private[hive] object SparkSQLEnv extends Logging { } /** Cleans up and shuts down the Spark SQL environments. 
*/ - def stop() { + def stop(): Unit = { logDebug("Shutting down Spark SQL Environment") // Stop the SparkContext if (SparkSQLEnv.sparkContext != null) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 13055e0ae1394..c4248bfde38cc 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -38,7 +38,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: private lazy val sparkSqlOperationManager = new SparkSQLOperationManager() - override def init(hiveConf: HiveConf) { + override def init(hiveConf: HiveConf): Unit = { setSuperField(this, "operationManager", sparkSqlOperationManager) super.init(hiveConf) } @@ -63,6 +63,9 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: sqlContext.newSession() } ctx.setConf(HiveUtils.FAKE_HIVE_VERSION.key, HiveUtils.builtinHiveVersion) + val hiveSessionState = session.getSessionState + setConfMap(ctx, hiveSessionState.getOverriddenConfigurations) + setConfMap(ctx, hiveSessionState.getHiveVariables) if (sessionConf != null && sessionConf.containsKey("use:database")) { ctx.sql(s"use ${sessionConf.get("use:database")}") } @@ -70,10 +73,18 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: sessionHandle } - override def closeSession(sessionHandle: SessionHandle) { + override def closeSession(sessionHandle: SessionHandle): Unit = { HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) super.closeSession(sessionHandle) sparkSqlOperationManager.sessionToActivePool.remove(sessionHandle) sparkSqlOperationManager.sessionToContexts.remove(sessionHandle) } + + def setConfMap(conf: SQLContext, confMap: java.util.Map[String, String]): Unit = { + val iterator = confMap.entrySet().iterator() + while (iterator.hasNext) { + val kv = iterator.next() + conf.setConf(kv.getKey, kv.getValue) + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 35f92547e7815..3396560f43502 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver._ -import org.apache.spark.sql.internal.SQLConf /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. 
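The session-manager change above also moves where per-connection settings land: openSession now copies the connection's overridden Hive configurations and Hive variables into that session's SQLContext, instead of leaving it to the operation manager. The practical effect, which the updated HiveThriftServer2Suites test further below exercises, is that values supplied on the JDBC URL (the hiveconf section after '?') or changed later with SET are visible to ${...} variable substitution. A hedged usage sketch over plain JDBC; the host, port and key name are illustrative, and the Hive JDBC driver must be on the classpath:

import java.sql.DriverManager

object SessionConfSketch {
  def main(args: Array[String]): Unit = {
    // '?myKey=initial' sets a hiveconf entry for this connection only.
    val url = "jdbc:hive2://localhost:10000/default?myKey=initial"
    val conn = DriverManager.getConnection(url, System.getProperty("user.name"), "")
    try {
      val stmt = conn.createStatement()

      // The session-level value is visible to ${...} substitution.
      val rs1 = stmt.executeQuery("SELECT '${myKey}'")
      rs1.next()
      assert(rs1.getString(1) == "initial")

      // SET overrides it for the rest of the session.
      stmt.execute("SET myKey=updated")
      val rs2 = stmt.executeQuery("SELECT '${myKey}'")
      rs2.next()
      assert(rs2.getString(1) == "updated")
    } finally {
      conn.close()
    }
  }
}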
@@ -51,9 +50,6 @@ private[thriftserver] class SparkSQLOperationManager() require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + s" initialized or had already closed.") val conf = sqlContext.sessionState.conf - val hiveSessionState = parentSession.getSessionState - setConfMap(conf, hiveSessionState.getOverriddenConfigurations) - setConfMap(conf, hiveSessionState.getHiveVariables) val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool) @@ -145,11 +141,14 @@ private[thriftserver] class SparkSQLOperationManager() operation } - def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = { - val iterator = confMap.entrySet().iterator() - while (iterator.hasNext) { - val kv = iterator.next() - conf.setConfString(kv.getKey, kv.getValue) - } + override def newGetTypeInfoOperation( + parentSession: HiveSession): GetTypeInfoOperation = synchronized { + val sqlContext = sessionToContexts.get(parentSession.getSessionHandle) + require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + + " initialized or had already closed.") + val operation = new SparkGetTypeInfoOperation(sqlContext, parentSession) + handleToOperation.put(operation.getHandle, operation) + logDebug(s"Created GetTypeInfoOperation with session=$parentSession.") + operation } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 261e8fc912eb9..4056be4769d21 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -26,6 +26,7 @@ import org.apache.commons.text.StringEscapeUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState, SessionInfo} +import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._ import org.apache.spark.ui._ import org.apache.spark.ui.UIUtils._ @@ -72,6 +73,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Close Time", "Execution Time", "Duration", "Statement", "State", "Detail") + val tooltips = Seq(None, None, None, None, Some(THRIFT_SERVER_FINISH_TIME), + Some(THRIFT_SERVER_CLOSE_TIME), Some(THRIFT_SERVER_EXECUTION), + Some(THRIFT_SERVER_DURATION), None, None, None) + assert(headerRow.length == tooltips.length) val dataRows = listener.getExecutionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { @@ -91,8 +96,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {formatDate(info.startTimestamp)} {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)} {if (info.closeTimestamp > 0) formatDate(info.closeTimestamp)} - {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))} - {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))} + + {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))} + + {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))} {info.statement} 
{info.state} {errorMessageCell(detail)} @@ -100,7 +107,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } Some(UIUtils.listingTable(headerRow, generateDataRow, - dataRows, false, None, Seq(null), false)) + dataRows, false, None, Seq(null), false, tooltipHeaders = tooltips)) } else { None } @@ -157,7 +164,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {session.sessionId} {formatDate(session.startTimestamp)} {if (session.finishTimestamp > 0) formatDate(session.finishTimestamp)} - {formatDurationOption(Some(session.totalTime))} + + {formatDurationOption(Some(session.totalTime))} {session.totalExecution.toString} } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 81df1304085e8..0aa0a2b8335d8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -26,6 +26,7 @@ import org.apache.commons.text.StringEscapeUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState} +import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._ import org.apache.spark.ui._ import org.apache.spark.ui.UIUtils._ @@ -81,6 +82,10 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Close Time", "Execution Time", "Duration", "Statement", "State", "Detail") + val tooltips = Seq(None, None, None, None, Some(THRIFT_SERVER_FINISH_TIME), + Some(THRIFT_SERVER_CLOSE_TIME), Some(THRIFT_SERVER_EXECUTION), + Some(THRIFT_SERVER_DURATION), None, None, None) + assert(headerRow.length == tooltips.length) val dataRows = executionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { @@ -98,10 +103,14 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) {info.groupId} {formatDate(info.startTimestamp)} - {formatDate(info.finishTimestamp)} - {formatDate(info.closeTimestamp)} - {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))} - {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))} + {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)} + {if (info.closeTimestamp > 0) formatDate(info.closeTimestamp)} + + {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))} + + + {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))} + {info.statement} {info.state} {errorMessageCell(detail)} @@ -109,7 +118,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) } Some(UIUtils.listingTable(headerRow, generateDataRow, - dataRows, false, None, Seq(null), false)) + dataRows, false, None, Seq(null), false, tooltipHeaders = tooltips)) } else { None } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index db2066009b351..8efb2c3311cfe 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -39,7 +39,7 @@ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) attachPage(new ThriftServerSessionPage(this)) parent.attachTab(this) - def detach() { + def detach(): Unit = { getSparkUI(sparkContext).detachTab(this) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ToolTips.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ToolTips.scala new file mode 100644 index 0000000000000..1990b8f2d3285 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ToolTips.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.ui + +private[ui] object ToolTips { + val THRIFT_SERVER_FINISH_TIME = + "Execution finish time, before fetching the results" + + val THRIFT_SERVER_CLOSE_TIME = + "Operation close time after fetching the results" + + val THRIFT_SERVER_EXECUTION = + "Difference between start time and finish time" + + val THRIFT_SERVER_DURATION = + "Difference between start time and close time" +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 6e042ac41d9da..f3063675a79f7 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -27,12 +27,11 @@ import scala.concurrent.Promise import scala.concurrent.duration._ import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.HiveTestJars import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.{ThreadUtils, Utils} @@ -202,7 +201,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { } test("Commands using SerDe provided in --jars") { - val jarFile = HiveTestUtils.getHiveHcatalogCoreJar.getCanonicalPath + val jarFile = HiveTestJars.getHiveHcatalogCoreJar().getCanonicalPath val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") @@ -218,8 +217,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { -> "", "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" -> "", - "SELECT count(key) FROM t1;" 
- -> "5", + "SELECT collect_list(array(val)) FROM t1;" + -> """[["val_238"],["val_86"],["val_311"],["val_27"],["val_165"]]""", "DROP TABLE t1;" -> "", "DROP TABLE sourceTable;" @@ -227,6 +226,32 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { ) } + test("SPARK-29022: Commands using SerDe provided in --hive.aux.jars.path") { + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + val hiveContribJar = HiveTestJars.getHiveHcatalogCoreJar().getCanonicalPath + runCliWithin( + 3.minute, + Seq("--conf", s"spark.hadoop.${ConfVars.HIVEAUXJARS}=$hiveContribJar"))( + """CREATE TABLE addJarWithHiveAux(key string, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; + """.stripMargin + -> "", + "CREATE TABLE sourceTableForWithHiveAux (key INT, val STRING);" + -> "", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTableForWithHiveAux;" + -> "", + "INSERT INTO TABLE addJarWithHiveAux SELECT key, val FROM sourceTableForWithHiveAux;" + -> "", + "SELECT collect_list(array(val)) FROM addJarWithHiveAux;" + -> """[["val_238"],["val_86"],["val_311"],["val_27"],["val_165"]]""", + "DROP TABLE addJarWithHiveAux;" + -> "", + "DROP TABLE sourceTableForWithHiveAux;" + -> "" + ) + } + test("SPARK-11188 Analysis error reporting") { runCliWithin(timeout = 2.minute, errorResponses = Seq("AnalysisException"))( @@ -297,12 +322,66 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { } test("Support hive.aux.jars.path") { - val hiveContribJar = HiveTestUtils.getHiveContribJar.getCanonicalPath + val hiveContribJar = HiveTestJars.getHiveContribJar().getCanonicalPath runCliWithin( 1.minute, Seq("--conf", s"spark.hadoop.${ConfVars.HIVEAUXJARS}=$hiveContribJar"))( - s"CREATE TEMPORARY FUNCTION example_max AS '${classOf[UDAFExampleMax].getName}';" -> "", - "SELECT example_max(1);" -> "1" + "CREATE TEMPORARY FUNCTION example_format AS " + + "'org.apache.hadoop.hive.contrib.udf.example.UDFExampleFormat';" -> "", + "SELECT example_format('%o', 93);" -> "135" + ) + } + + test("SPARK-28840 test --jars command") { + val jarFile = new File("../../sql/hive/src/test/resources/SPARK-21101-1.0.jar").getCanonicalPath + runCliWithin( + 1.minute, + Seq("--jars", s"$jarFile"))( + "CREATE TEMPORARY FUNCTION testjar AS" + + " 'org.apache.spark.sql.hive.execution.UDTFStack';" -> "", + "SELECT testjar(1,'TEST-SPARK-TEST-jar', 28840);" -> "TEST-SPARK-TEST-jar\t28840" + ) + } + + test("SPARK-28840 test --jars and hive.aux.jars.path command") { + val jarFile = new File("../../sql/hive/src/test/resources/SPARK-21101-1.0.jar").getCanonicalPath + val hiveContribJar = HiveTestJars.getHiveContribJar().getCanonicalPath + runCliWithin( + 1.minute, + Seq("--jars", s"$jarFile", "--conf", + s"spark.hadoop.${ConfVars.HIVEAUXJARS}=$hiveContribJar"))( + "CREATE TEMPORARY FUNCTION testjar AS" + + " 'org.apache.spark.sql.hive.execution.UDTFStack';" -> "", + "SELECT testjar(1,'TEST-SPARK-TEST-jar', 28840);" -> "TEST-SPARK-TEST-jar\t28840", + "CREATE TEMPORARY FUNCTION example_max AS " + + "'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax';" -> "", + "SELECT concat_ws(',', 'First', example_max(1234321), 'Third');" -> "First,1234321,Third" + ) + } + + test("SPARK-29022 Commands using SerDe provided in ADD JAR sql") { + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + val hiveContribJar = HiveTestJars.getHiveHcatalogCoreJar().getCanonicalPath + 
runCliWithin( + 3.minute)( + s"ADD JAR ${hiveContribJar};" -> "", + """CREATE TABLE addJarWithSQL(key string, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; + """.stripMargin + -> "", + "CREATE TABLE sourceTableForWithSQL(key INT, val STRING);" + -> "", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTableForWithSQL;" + -> "", + "INSERT INTO TABLE addJarWithSQL SELECT key, val FROM sourceTableForWithSQL;" + -> "", + "SELECT collect_list(array(val)) FROM addJarWithSQL;" + -> """[["val_238"],["val_86"],["val_311"],["val_27"],["val_165"]]""", + "DROP TABLE addJarWithSQL;" + -> "", + "DROP TABLE sourceTableForWithSQL;" + -> "" ) } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b7185db2f2ae7..8a5526ea780ef 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -43,7 +43,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.test.HiveTestUtils +import org.apache.spark.sql.hive.test.HiveTestJars import org.apache.spark.sql.internal.StaticSQLConf.HIVE_THRIFT_SERVER_SINGLESESSION import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.{ThreadUtils, Utils} @@ -144,10 +144,17 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { def executeTest(hiveList: String): Unit = { hiveList.split(";").foreach{ m => val kv = m.split("=") - // select "${a}"; ---> avalue - val resultSet = statement.executeQuery("select \"${" + kv(0) + "}\"") + val k = kv(0) + val v = kv(1) + val modValue = s"${v}_MOD_VALUE" + // select '${a}'; ---> avalue + val resultSet = statement.executeQuery(s"select '$${$k}'") resultSet.next() - assert(resultSet.getString(1) === kv(1)) + assert(resultSet.getString(1) === v) + statement.executeQuery(s"set $k=$modValue") + val modResultSet = statement.executeQuery(s"select '$${$k}'") + modResultSet.next() + assert(modResultSet.getString(1) === s"$modValue") } } } @@ -485,7 +492,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withMultipleConnectionJdbcStatement("smallKV", "addJar")( { statement => - val jarFile = HiveTestUtils.getHiveHcatalogCoreJar.getCanonicalPath + val jarFile = HiveTestJars.getHiveHcatalogCoreJar().getCanonicalPath statement.executeQuery(s"ADD JAR $jarFile") }, @@ -662,6 +669,21 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(rs.getBigDecimal(1) === new java.math.BigDecimal("1.000000000000000000")) } } + + test("Support interval type") { + withJdbcStatement() { statement => + val rs = statement.executeQuery("SELECT interval 3 months 1 hours") + assert(rs.next()) + assert(rs.getString(1) === "interval 3 months 1 hours") + } + // Invalid interval value + withJdbcStatement() { statement => + val e = intercept[SQLException] { + statement.executeQuery("SELECT interval 3 months 1 hou") + } + assert(e.getMessage.contains("org.apache.spark.sql.catalyst.parser.ParseException")) + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -820,7 +842,7 @@ abstract class HiveThriftJdbcTest extends 
HiveThriftServer2Test { s"jdbc:hive2://localhost:$serverPort/?${hiveConfList}#${hiveVarList}" } - def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*) { + def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*): Unit = { val user = System.getProperty("user.name") val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") } val statements = connections.map(_.createStatement()) @@ -841,7 +863,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { } } - def withDatabase(dbNames: String*)(fs: (Statement => Unit)*) { + def withDatabase(dbNames: String*)(fs: (Statement => Unit)*): Unit = { val user = System.getProperty("user.name") val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") } val statements = connections.map(_.createStatement()) @@ -857,7 +879,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { } } - def withJdbcStatement(tableNames: String*)(f: Statement => Unit) { + def withJdbcStatement(tableNames: String*)(f: Statement => Unit): Unit = { withMultipleConnectionJdbcStatement(tableNames: _*)(f) } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala index 21870ffd463ec..f7ee3e0a46cd1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala @@ -231,4 +231,20 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest { assert(!rs.next()) } } + + test("GetTypeInfo Thrift API") { + def checkResult(rs: ResultSet, typeNames: Seq[String]): Unit = { + for (i <- typeNames.indices) { + assert(rs.next()) + assert(rs.getString("TYPE_NAME") === typeNames(i)) + } + // Make sure there are no more elements + assert(!rs.next()) + } + + withJdbcStatement() { statement => + val metaData = statement.getConnection.getMetaData + checkResult(metaData.getTypeInfo, ThriftserverShimUtils.supportedType().map(_.getName)) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala index f198372a4c998..10ec1ee168303 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala @@ -261,10 +261,10 @@ class SparkThriftServerProtocolVersionsSuite extends HiveThriftJdbcTest { } } - // We do not fully support interval type - ignore(s"$version get interval type") { + test(s"$version get interval type") { testExecuteStatementWithProtocolVersion(version, "SELECT interval '1' year '2' day") { rs => assert(rs.next()) + assert(rs.getString(1) === "interval 1 years 2 days") } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index 1f7b3feae47b5..613c1655727bb 100644 --- 
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -18,19 +18,21 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File -import java.sql.{DriverManager, SQLException, Statement, Timestamp} -import java.util.Locale +import java.sql.{DriverManager, Statement, Timestamp} +import java.util.{Locale, MissingFormatArgumentException} import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.HiveSQLException -import org.scalatest.Ignore +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, SQLQueryTestSuite} +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.util.fileToString import org.apache.spark.sql.execution.HiveResult +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -43,12 +45,12 @@ import org.apache.spark.sql.types._ * 2. Support DESC command. * 3. Support SHOW command. */ -@Ignore class ThriftServerQueryTestSuite extends SQLQueryTestSuite { private var hiveServer2: HiveThriftServer2 = _ - override def beforeEach(): Unit = { + override def beforeAll(): Unit = { + super.beforeAll() // Chooses a random port between 10000 and 19999 var listeningPort = 10000 + Random.nextInt(10000) @@ -65,36 +67,40 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { logInfo("HiveThriftServer2 started successfully") } - override def afterEach(): Unit = { - hiveServer2.stop() + override def afterAll(): Unit = { + try { + hiveServer2.stop() + } finally { + super.afterAll() + } } + override def sparkConf: SparkConf = super.sparkConf + // Hive Thrift server should not executes SQL queries in an asynchronous way + // because we may set session configuration. + .set(HiveUtils.HIVE_THRIFT_SERVER_ASYNC, false) + override val isTestWithConfigSets = false /** List of test cases to ignore, in lower cases. */ override def blackList: Set[String] = Set( "blacklist.sql", // Do NOT remove this one. It is here to test the blacklist functionality. // Missing UDF - "pgSQL/boolean.sql", - "pgSQL/case.sql", + "postgreSQL/boolean.sql", + "postgreSQL/case.sql", // SPARK-28624 "date.sql", - // SPARK-28619 - "pgSQL/aggregates_part1.sql", - "group-by.sql", // SPARK-28620 - "pgSQL/float4.sql", + "postgreSQL/float4.sql", // SPARK-28636 "decimalArithmeticOperations.sql", "literals.sql", "subquery/scalar-subquery/scalar-subquery-predicate.sql", "subquery/in-subquery/in-limit.sql", + "subquery/in-subquery/in-group-by.sql", "subquery/in-subquery/simple-in.sql", "subquery/in-subquery/in-order-by.sql", - "subquery/in-subquery/in-set-operations.sql", - // SPARK-28637 - "cast.sql", - "ansi/interval.sql" + "subquery/in-subquery/in-set-operations.sql" ) override def runQueries( @@ -110,8 +116,8 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { case _: PgSQLTest => // PostgreSQL enabled cartesian product by default. 
statement.execute(s"SET ${SQLConf.CROSS_JOINS_ENABLED.key} = true") - statement.execute(s"SET ${SQLConf.ANSI_SQL_PARSER.key} = true") - statement.execute(s"SET ${SQLConf.PREFER_INTEGRAL_DIVISION.key} = true") + statement.execute(s"SET ${SQLConf.ANSI_ENABLED.key} = true") + statement.execute(s"SET ${SQLConf.DIALECT.key} = ${SQLConf.Dialect.POSTGRESQL.toString}") case _ => } @@ -166,19 +172,42 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { || d.sql.toUpperCase(Locale.ROOT).startsWith("DESC\n") || d.sql.toUpperCase(Locale.ROOT).startsWith("DESCRIBE ") || d.sql.toUpperCase(Locale.ROOT).startsWith("DESCRIBE\n") => + // Skip show command, see HiveResult.hiveResultString case s if s.sql.toUpperCase(Locale.ROOT).startsWith("SHOW ") || s.sql.toUpperCase(Locale.ROOT).startsWith("SHOW\n") => - // AnalysisException should exactly match. - // SQLException should not exactly match. We only assert the result contains Exception. - case _ if output.output.startsWith(classOf[SQLException].getName) => + + case _ if output.output.startsWith(classOf[NoSuchTableException].getPackage.getName) => + assert(expected.output.startsWith(classOf[NoSuchTableException].getPackage.getName), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + + case _ if output.output.startsWith(classOf[SparkException].getName) && + output.output.contains("overflow") => + assert(expected.output.contains(classOf[ArithmeticException].getName) && + expected.output.contains("overflow"), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + + case _ if output.output.startsWith(classOf[RuntimeException].getName) => assert(expected.output.contains("Exception"), s"Exception did not match for query #$i\n${expected.sql}, " + s"expected: ${expected.output}, but got: ${output.output}") - // HiveSQLException is usually a feature that our ThriftServer cannot support. - // Please add SQL to blackList. 
- case _ if output.output.startsWith(classOf[HiveSQLException].getName) => - assert(false, s"${output.output} for query #$i\n${expected.sql}") + + case _ if output.output.startsWith(classOf[ArithmeticException].getName) && + output.output.contains("causes overflow") => + assert(expected.output.contains(classOf[ArithmeticException].getName) && + expected.output.contains("causes overflow"), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + + case _ if output.output.startsWith(classOf[MissingFormatArgumentException].getName) && + output.output.contains("Format specifier") => + assert(expected.output.contains(classOf[MissingFormatArgumentException].getName) && + expected.output.contains("Format specifier"), + s"Exception did not match for query #$i\n${expected.sql}, " + + s"expected: ${expected.output}, but got: ${output.output}") + case _ => assertResult(expected.output, s"Result did not match for query #$i\n${expected.sql}") { output.output @@ -209,7 +238,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udf")) { Seq.empty - } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}pgSQL")) { + } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}postgreSQL")) { PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil } else { RegularTestCase(testCaseName, absPath, resultFile) :: Nil @@ -248,8 +277,9 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")).sorted case NonFatal(e) => + val rootCause = ExceptionUtils.getRootCause(e) // If there is an exception, put the exception class followed by the message. 
- Seq(e.getClass.getName, e.getMessage) + Seq(rootCause.getClass.getName, rootCause.getMessage) } } @@ -260,7 +290,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { hiveServer2 = HiveThriftServer2.startWithContext(sqlContext) } - private def withJdbcStatement(fs: (Statement => Unit)*) { + private def withJdbcStatement(fs: (Statement => Unit)*): Unit = { val user = System.getProperty("user.name") val serverPort = hiveServer2.getHiveConf.get(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname) @@ -337,7 +367,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite { upperCase.startsWith("SELECT ") || upperCase.startsWith("SELECT\n") || upperCase.startsWith("WITH ") || upperCase.startsWith("WITH\n") || upperCase.startsWith("VALUES ") || upperCase.startsWith("VALUES\n") || - // pgSQL/union.sql + // postgreSQL/union.sql upperCase.startsWith("(") } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 47cf4f104d204..7f731f3d05e51 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -24,8 +24,8 @@ import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfterAll, Matchers} import org.scalatest.concurrent.Eventually._ -import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ +import org.scalatestplus.selenium.WebBrowser import org.apache.spark.ui.SparkUICssErrorHandler diff --git a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java index 0f72071d7e7d1..3e81f8afbd85f 100644 --- a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java +++ b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java @@ -73,7 +73,7 @@ public class GetTypeInfoOperation extends MetadataOperation { .addPrimitiveColumn("NUM_PREC_RADIX", Type.INT_TYPE, "Usually 2 or 10"); - private final RowSet rowSet; + protected final RowSet rowSet; protected GetTypeInfoOperation(HiveSession parentSession) { super(parentSession, OperationType.GET_TYPE_INFO); diff --git a/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala b/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala index 87c0f8f6a571a..fbfc698ecb4bf 100644 --- a/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala +++ b/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.commons.logging.LogFactory -import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hive.service.cli.{RowSet, RowSetFactory, TableSchema, Type} +import org.apache.hive.service.cli.Type._ import org.apache.hive.service.cli.thrift.TProtocolVersion._ /** @@ -51,10 +51,12 @@ private[thriftserver] object ThriftserverShimUtils { private[thriftserver] def 
toJavaSQLType(s: String): Int = Type.getType(s).toJavaSQLType - private[thriftserver] def addToClassPath( - loader: ClassLoader, - auxJars: Array[String]): ClassLoader = { - Utilities.addToClassPath(loader, auxJars) + private[thriftserver] def supportedType(): Seq[Type] = { + Seq(NULL_TYPE, BOOLEAN_TYPE, STRING_TYPE, BINARY_TYPE, + TINYINT_TYPE, SMALLINT_TYPE, INT_TYPE, BIGINT_TYPE, + FLOAT_TYPE, DOUBLE_TYPE, DECIMAL_TYPE, + DATE_TYPE, TIMESTAMP_TYPE, + ARRAY_TYPE, MAP_TYPE, STRUCT_TYPE) } private[thriftserver] val testedProtocolVersions = Seq( diff --git a/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java b/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java index 9612eb145638c..0f57a72e2a1ce 100644 --- a/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java +++ b/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java @@ -73,7 +73,7 @@ public class GetTypeInfoOperation extends MetadataOperation { .addPrimitiveColumn("NUM_PREC_RADIX", Type.INT_TYPE, "Usually 2 or 10"); - private final RowSet rowSet; + protected final RowSet rowSet; protected GetTypeInfoOperation(HiveSession parentSession) { super(parentSession, OperationType.GET_TYPE_INFO); diff --git a/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala b/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala index 124c9937c0fca..850382fe2bfd7 100644 --- a/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala +++ b/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala @@ -17,13 +17,9 @@ package org.apache.spark.sql.hive.thriftserver -import java.security.AccessController - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.hive.ql.exec.AddToClassPathAction import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.thrift.Type +import org.apache.hadoop.hive.serde2.thrift.Type._ import org.apache.hive.service.cli.{RowSet, RowSetFactory, TableSchema} import org.apache.hive.service.rpc.thrift.TProtocolVersion._ import org.slf4j.LoggerFactory @@ -56,11 +52,12 @@ private[thriftserver] object ThriftserverShimUtils { private[thriftserver] def toJavaSQLType(s: String): Int = Type.getType(s).toJavaSQLType - private[thriftserver] def addToClassPath( - loader: ClassLoader, - auxJars: Array[String]): ClassLoader = { - val addAction = new AddToClassPathAction(loader, auxJars.toList.asJava) - AccessController.doPrivileged(addAction) + private[thriftserver] def supportedType(): Seq[Type] = { + Seq(NULL_TYPE, BOOLEAN_TYPE, STRING_TYPE, BINARY_TYPE, + TINYINT_TYPE, SMALLINT_TYPE, INT_TYPE, BIGINT_TYPE, + FLOAT_TYPE, DOUBLE_TYPE, DECIMAL_TYPE, + DATE_TYPE, TIMESTAMP_TYPE, + ARRAY_TYPE, MAP_TYPE, STRUCT_TYPE) } private[thriftserver] val testedProtocolVersions = Seq( diff --git a/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt b/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt index f3044da972497..0c394a340333a 100644 --- a/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt +++ b/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt @@ -2,44 +2,44 @@ Hive UDAF vs Spark AF 
================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -hive udaf w/o group by 6370 / 6400 0.0 97193.6 1.0X -spark af w/o group by 54 / 63 1.2 820.8 118.4X -hive udaf w/ group by 4492 / 4507 0.0 68539.5 1.4X -spark af w/ group by w/o fallback 58 / 64 1.1 881.7 110.2X -spark af w/ group by w/ fallback 136 / 142 0.5 2075.0 46.8X +hive udaf vs spark af: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +hive udaf w/o group by 6741 6759 22 0.0 102864.5 1.0X +spark af w/o group by 56 66 9 1.2 851.6 120.8X +hive udaf w/ group by 4610 4642 25 0.0 70350.3 1.5X +spark af w/ group by w/o fallback 60 67 8 1.1 916.7 112.2X +spark af w/ group by w/ fallback 135 144 9 0.5 2065.6 49.8X ================================================================================================ ObjectHashAggregateExec vs SortAggregateExec - typed_count ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -object agg v.s. sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -sort agg w/ group by 41500 / 41630 2.5 395.8 1.0X -object agg w/ group by w/o fallback 10075 / 10122 10.4 96.1 4.1X -object agg w/ group by w/ fallback 28131 / 28205 3.7 268.3 1.5X -sort agg w/o group by 6182 / 6221 17.0 59.0 6.7X -object agg w/o group by w/o fallback 5435 / 5468 19.3 51.8 7.6X +object agg v.s. sort agg: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +sort agg w/ group by 41568 41894 461 2.5 396.4 1.0X +object agg w/ group by w/o fallback 10314 10494 149 10.2 98.4 4.0X +object agg w/ group by w/ fallback 26720 26951 326 3.9 254.8 1.6X +sort agg w/o group by 6638 6681 38 15.8 63.3 6.3X +object agg w/o group by w/o fallback 5665 5706 30 18.5 54.0 7.3X ================================================================================================ ObjectHashAggregateExec vs SortAggregateExec - percentile_approx ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -object agg v.s. 
sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -sort agg w/ group by 970 / 1025 2.2 462.5 1.0X -object agg w/ group by w/o fallback 772 / 798 2.7 368.1 1.3X -object agg w/ group by w/ fallback 1013 / 1044 2.1 483.1 1.0X -sort agg w/o group by 751 / 781 2.8 358.0 1.3X -object agg w/o group by w/o fallback 772 / 814 2.7 368.0 1.3X +object agg v.s. sort agg: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +sort agg w/ group by 794 862 33 2.6 378.8 1.0X +object agg w/ group by w/o fallback 605 622 10 3.5 288.5 1.3X +object agg w/ group by w/ fallback 840 860 15 2.5 400.5 0.9X +sort agg w/o group by 555 570 12 3.8 264.6 1.4X +object agg w/o group by w/o fallback 544 562 12 3.9 259.6 1.5X diff --git a/sql/hive/benchmarks/OrcReadBenchmark-results.txt b/sql/hive/benchmarks/OrcReadBenchmark-results.txt index caa78b9a8f102..c47cf27bf617a 100644 --- a/sql/hive/benchmarks/OrcReadBenchmark-results.txt +++ b/sql/hive/benchmarks/OrcReadBenchmark-results.txt @@ -2,155 +2,155 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 1725 / 1759 9.1 109.7 1.0X -Native ORC Vectorized 272 / 316 57.8 17.3 6.3X -Hive built-in ORC 1970 / 1987 8.0 125.3 0.9X +SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 1843 1958 162 8.5 117.2 1.0X +Native ORC Vectorized 321 355 31 48.9 20.4 5.7X +Hive built-in ORC 2143 2175 44 7.3 136.3 0.9X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 1633 / 1672 9.6 103.8 1.0X -Native ORC Vectorized 238 / 255 66.0 15.1 6.9X -Hive built-in ORC 2293 / 2305 6.9 145.8 0.7X +SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 1987 2020 47 7.9 126.3 1.0X +Native ORC Vectorized 276 299 25 57.0 17.6 7.2X +Hive built-in ORC 2350 2357 10 6.7 149.4 0.8X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 1677 / 1699 9.4 106.6 1.0X -Native ORC Vectorized 325 / 342 48.3 20.7 
5.2X -Hive built-in ORC 2561 / 2569 6.1 162.8 0.7X +SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 2092 2115 32 7.5 133.0 1.0X +Native ORC Vectorized 360 373 18 43.6 22.9 5.8X +Hive built-in ORC 2550 2557 9 6.2 162.2 0.8X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 1791 / 1795 8.8 113.9 1.0X -Native ORC Vectorized 400 / 408 39.3 25.4 4.5X -Hive built-in ORC 2713 / 2720 5.8 172.5 0.7X +SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 2173 2188 21 7.2 138.2 1.0X +Native ORC Vectorized 435 448 14 36.2 27.7 5.0X +Hive built-in ORC 2683 2690 10 5.9 170.6 0.8X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 1791 / 1805 8.8 113.8 1.0X -Native ORC Vectorized 433 / 438 36.3 27.5 4.1X -Hive built-in ORC 2690 / 2803 5.8 171.0 0.7X +SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 2233 2323 127 7.0 142.0 1.0X +Native ORC Vectorized 475 483 13 33.1 30.2 4.7X +Hive built-in ORC 2605 2610 6 6.0 165.7 0.9X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 1911 / 1930 8.2 121.5 1.0X -Native ORC Vectorized 543 / 552 29.0 34.5 3.5X -Hive built-in ORC 2967 / 3065 5.3 188.6 0.6X +SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 2367 2384 24 6.6 150.5 1.0X +Native ORC Vectorized 600 641 69 26.2 38.1 3.9X +Hive built-in ORC 2860 2877 24 5.5 181.9 0.8X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------- -Native ORC MR 4160 / 4188 2.5 396.7 1.0X -Native ORC Vectorized 2405 / 2406 4.4 229.4 1.7X -Hive built-in ORC 5514 / 5562 1.9 525.9 0.8X +Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 4253 4330 108 2.5 405.6 1.0X +Native ORC Vectorized 2295 2301 8 4.6 218.9 1.9X +Hive built-in ORC 5364 5465 144 2.0 511.5 0.8X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Data column - Native ORC MR 1863 / 1867 8.4 118.4 1.0X -Data column - Native ORC Vectorized 411 / 418 38.2 26.2 4.5X -Data column - Hive built-in ORC 3297 / 3308 4.8 209.6 0.6X -Partition column - Native ORC MR 1505 / 1506 10.4 95.7 1.2X -Partition column - Native ORC Vectorized 80 / 93 195.6 5.1 23.2X -Partition column - Hive built-in ORC 1960 / 1979 8.0 124.6 1.0X -Both columns - Native ORC MR 2076 / 2090 7.6 132.0 0.9X -Both columns - Native ORC Vectorized 450 / 463 34.9 28.6 4.1X -Both columns - Hive built-in ORC 3528 / 3548 4.5 224.3 0.5X +Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Data column - Native ORC MR 2443 2448 6 6.4 155.3 1.0X +Data column - Native ORC Vectorized 446 473 44 35.3 28.3 5.5X +Data column - Hive built-in ORC 2868 2877 12 5.5 182.4 0.9X +Partition column - Native ORC MR 1623 1656 47 9.7 103.2 1.5X +Partition column - Native ORC Vectorized 112 121 14 140.8 7.1 21.9X +Partition column - Hive built-in ORC 1846 1850 5 8.5 117.4 1.3X +Both columns - Native ORC MR 2610 2635 36 6.0 165.9 0.9X +Both columns - Native ORC Vectorized 492 508 19 32.0 31.3 5.0X +Both columns - Hive built-in ORC 2969 2973 4 5.3 188.8 0.8X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 1727 / 1733 6.1 164.7 1.0X -Native ORC Vectorized 375 / 379 28.0 35.7 4.6X -Hive built-in ORC 2665 / 2666 3.9 254.2 0.6X +Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 2056 2064 11 5.1 196.1 1.0X +Native ORC Vectorized 415 421 7 25.3 39.6 5.0X +Hive built-in ORC 2710 2722 17 3.9 258.4 0.8X 
================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 3324 / 3325 3.2 317.0 1.0X -Native ORC Vectorized 1085 / 1106 9.7 103.4 3.1X -Hive built-in ORC 5272 / 5299 2.0 502.8 0.6X +String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 3655 3674 27 2.9 348.6 1.0X +Native ORC Vectorized 1166 1167 1 9.0 111.2 3.1X +Hive built-in ORC 5268 5305 52 2.0 502.4 0.7X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 3045 / 3046 3.4 290.4 1.0X -Native ORC Vectorized 1248 / 1260 8.4 119.0 2.4X -Hive built-in ORC 3989 / 3999 2.6 380.4 0.8X +String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 3447 3467 27 3.0 328.8 1.0X +Native ORC Vectorized 1222 1223 1 8.6 116.6 2.8X +Hive built-in ORC 3947 3959 18 2.7 376.4 0.9X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 1692 / 1694 6.2 161.3 1.0X -Native ORC Vectorized 471 / 493 22.3 44.9 3.6X -Hive built-in ORC 2398 / 2411 4.4 228.7 0.7X +String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 1912 1917 6 5.5 182.4 1.0X +Native ORC Vectorized 477 484 5 22.0 45.5 4.0X +Hive built-in ORC 2374 2386 17 4.4 226.4 0.8X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 1371 / 1379 0.8 1307.5 1.0X -Native ORC Vectorized 121 / 135 8.6 115.8 11.3X -Hive built-in ORC 521 / 
561 2.0 497.1 2.6X +Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 290 350 102 3.6 276.1 1.0X +Native ORC Vectorized 155 166 15 6.7 148.2 1.9X +Hive built-in ORC 520 531 8 2.0 495.8 0.6X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 2711 / 2767 0.4 2585.5 1.0X -Native ORC Vectorized 210 / 232 5.0 200.5 12.9X -Hive built-in ORC 764 / 775 1.4 728.3 3.5X +Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 365 406 73 2.9 347.9 1.0X +Native ORC Vectorized 232 246 20 4.5 221.6 1.6X +Hive built-in ORC 794 864 62 1.3 757.6 0.5X -OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------- -Native ORC MR 3979 / 3988 0.3 3794.4 1.0X -Native ORC Vectorized 357 / 366 2.9 340.2 11.2X -Hive built-in ORC 1091 / 1095 1.0 1040.5 3.6X +Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +Native ORC MR 501 544 40 2.1 477.6 1.0X +Native ORC Vectorized 365 386 33 2.9 348.0 1.4X +Hive built-in ORC 1153 1153 0 0.9 1099.8 0.4X diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index e7ff3a5f4be2b..7a9f5c67fc693 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -46,7 +46,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) } - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) @@ -65,7 +65,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { RuleExecutor.resetMetrics() } - override def afterAll() { + override def afterAll(): Unit = { try { TestHive.setCacheTables(false) TimeZone.setDefault(originalTimeZone) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index c7d953a731b9b..b0cf25c3a7813 100644 --- 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -37,7 +37,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte private val originalLocale = Locale.getDefault private val testTempDir = Utils.createTempDir() - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) @@ -100,7 +100,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte sql("set mapreduce.jobtracker.address=local") } - override def afterAll() { + override def afterAll(): Unit = { try { TestHive.setCacheTables(false) TimeZone.setDefault(originalTimeZone) @@ -751,7 +751,7 @@ class HiveWindowFunctionQueryFileSuite private val originalLocale = Locale.getDefault private val testTempDir = Utils.createTempDir() - override def beforeAll() { + override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) @@ -769,7 +769,7 @@ class HiveWindowFunctionQueryFileSuite // sql("set mapreduce.jobtracker.address=local") } - override def afterAll() { + override def afterAll(): Unit = { try { TestHive.setCacheTables(false) TimeZone.setDefault(originalTimeZone) diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index d37f0c8573659..f627227aa0380 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -103,14 +103,6 @@ ${hive.group} hive-metastore - - ${hive.group} - hive-contrib - - - ${hive.group}.hcatalog - hive-hcatalog-core -
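
Editor's note: a minimal standalone sketch (not part of the patch) of what the new GetTypeInfo support added above looks like from a plain JDBC client. It assumes a Spark Thrift server is already listening on localhost:10000 with the default database, no password, and the Hive JDBC driver (org.apache.hive.jdbc.HiveDriver) on the classpath; the object name and connection details are illustrative only. Each row returned by DatabaseMetaData.getTypeInfo corresponds to one entry of ThriftserverShimUtils.supportedType() from the shim changes above.

import java.sql.DriverManager

object GetTypeInfoExample {
  def main(args: Array[String]): Unit = {
    // Explicitly load the Hive JDBC driver in case it is not auto-registered.
    Class.forName("org.apache.hive.jdbc.HiveDriver")
    val connection = DriverManager.getConnection(
      "jdbc:hive2://localhost:10000/default", System.getProperty("user.name"), "")
    try {
      // GetTypeInfo is served by GetTypeInfoOperation on the server side.
      val typeInfo = connection.getMetaData.getTypeInfo
      while (typeInfo.next()) {
        println(typeInfo.getString("TYPE_NAME"))
      }
    } finally {
      connection.close()
    }
  }
}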
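
Editor's note: a second small sketch (again not from the patch) of the root-cause normalisation idea that ThriftServerQueryTestSuite adopts above. When a query fails over JDBC, the interesting exception is typically wrapped in several layers, so the suite compares golden output against the innermost cause's class name and message via ExceptionUtils.getRootCause. The helper below assumes commons-lang3 is on the classpath; the object and method names are illustrative, and the null check guards against getRootCause returning null when the throwable has no cause.

import org.apache.commons.lang3.exception.ExceptionUtils

object RootCauseExample {
  // Describe an exception by the class name and message of its deepest cause,
  // falling back to the exception itself when it has no cause.
  def describe(e: Throwable): Seq[String] = {
    val rootCause = Option(ExceptionUtils.getRootCause(e)).getOrElse(e)
    Seq(rootCause.getClass.getName, String.valueOf(rootCause.getMessage))
  }

  def main(args: Array[String]): Unit = {
    val wrapped = new RuntimeException("outer",
      new IllegalStateException("middle", new ArithmeticException("causes overflow")))
    // Prints: java.lang.ArithmeticException, causes overflow
    println(describe(wrapped).mkString(", "))
  }
}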