diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml
index a9f757c3e2413..e2eb0683b6e59 100644
--- a/.github/workflows/master.yml
+++ b/.github/workflows/master.yml
@@ -4,6 +4,9 @@ on:
push:
branches:
- master
+ pull_request:
+ branches:
+ - master
jobs:
build:
@@ -12,16 +15,46 @@ jobs:
strategy:
matrix:
java: [ '1.8', '11' ]
- name: Build Spark with JDK ${{ matrix.java }}
+ hadoop: [ 'hadoop-2.7', 'hadoop-3.2' ]
+ exclude:
+ - java: '11'
+ hadoop: 'hadoop-2.7'
+ name: Build Spark with JDK ${{ matrix.java }} and ${{ matrix.hadoop }}
steps:
- uses: actions/checkout@master
- name: Set up JDK ${{ matrix.java }}
uses: actions/setup-java@v1
with:
- version: ${{ matrix.java }}
+ java-version: ${{ matrix.java }}
- name: Build with Maven
run: |
- export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN"
+ export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=1g -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN"
export MAVEN_CLI_OPTS="--no-transfer-progress"
- ./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pmesos -Pkubernetes -Phive -Phive-thriftserver -Phadoop-3.2 -Phadoop-cloud -Djava.version=${{ matrix.java }} package
+ ./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pmesos -Pkubernetes -Phive -Phive-thriftserver -P${{ matrix.hadoop }} -Phadoop-cloud -Djava.version=${{ matrix.java }} package
+
+
+ lint:
+ runs-on: ubuntu-latest
+ name: Linters
+ steps:
+ - uses: actions/checkout@master
+ - uses: actions/setup-java@v1
+ with:
+ java-version: '11'
+ - uses: actions/setup-python@v1
+ with:
+ python-version: '3.x'
+ architecture: 'x64'
+ - name: Scala
+ run: ./dev/lint-scala
+ - name: Java
+ run: ./dev/lint-java
+ - name: Python
+ run: |
+ pip install flake8 sphinx numpy
+ ./dev/lint-python
+ - name: License
+ run: ./dev/check-license
+ - name: Dependencies
+ run: ./dev/test-dependencies.sh
diff --git a/LICENSE-binary b/LICENSE-binary
index ba20eea118687..d2a189a3fca11 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -218,6 +218,7 @@ javax.jdo:jdo-api
joda-time:joda-time
net.sf.opencsv:opencsv
org.apache.derby:derby
+org.ehcache:ehcache
org.objenesis:objenesis
org.roaringbitmap:RoaringBitmap
org.scalanlp:breeze-macros_2.12
@@ -259,6 +260,7 @@ net.sf.supercsv:super-csv
org.apache.arrow:arrow-format
org.apache.arrow:arrow-memory
org.apache.arrow:arrow-vector
+org.apache.commons:commons-configuration2
org.apache.commons:commons-crypto
org.apache.commons:commons-lang3
org.apache.hadoop:hadoop-annotations
@@ -266,6 +268,7 @@ org.apache.hadoop:hadoop-auth
org.apache.hadoop:hadoop-client
org.apache.hadoop:hadoop-common
org.apache.hadoop:hadoop-hdfs
+org.apache.hadoop:hadoop-hdfs-client
org.apache.hadoop:hadoop-mapreduce-client-app
org.apache.hadoop:hadoop-mapreduce-client-common
org.apache.hadoop:hadoop-mapreduce-client-core
@@ -278,6 +281,21 @@ org.apache.hadoop:hadoop-yarn-server-common
org.apache.hadoop:hadoop-yarn-server-web-proxy
org.apache.httpcomponents:httpclient
org.apache.httpcomponents:httpcore
+org.apache.kerby:kerb-admin
+org.apache.kerby:kerb-client
+org.apache.kerby:kerb-common
+org.apache.kerby:kerb-core
+org.apache.kerby:kerb-crypto
+org.apache.kerby:kerb-identity
+org.apache.kerby:kerb-server
+org.apache.kerby:kerb-simplekdc
+org.apache.kerby:kerb-util
+org.apache.kerby:kerby-asn1
+org.apache.kerby:kerby-config
+org.apache.kerby:kerby-pkix
+org.apache.kerby:kerby-util
+org.apache.kerby:kerby-xdr
+org.apache.kerby:token-provider
org.apache.orc:orc-core
org.apache.orc:orc-mapreduce
org.mortbay.jetty:jetty
@@ -292,14 +310,19 @@ com.fasterxml.jackson.core:jackson-annotations
com.fasterxml.jackson.core:jackson-core
com.fasterxml.jackson.core:jackson-databind
com.fasterxml.jackson.dataformat:jackson-dataformat-yaml
+com.fasterxml.jackson.jaxrs:jackson-jaxrs-base
+com.fasterxml.jackson.jaxrs:jackson-jaxrs-json-provider
com.fasterxml.jackson.module:jackson-module-jaxb-annotations
com.fasterxml.jackson.module:jackson-module-paranamer
com.fasterxml.jackson.module:jackson-module-scala_2.12
+com.fasterxml.woodstox:woodstox-core
com.github.mifmif:generex
+com.github.stephenc.jcip:jcip-annotations
com.google.code.findbugs:jsr305
com.google.code.gson:gson
com.google.inject:guice
com.google.inject.extensions:guice-servlet
+com.nimbusds:nimbus-jose-jwt
com.twitter:parquet-hadoop-bundle
commons-cli:commons-cli
commons-dbcp:commons-dbcp
@@ -313,6 +336,8 @@ javax.inject:javax.inject
javax.validation:validation-api
log4j:apache-log4j-extras
log4j:log4j
+net.minidev:accessors-smart
+net.minidev:json-smart
net.sf.jpam:jpam
org.apache.avro:avro
org.apache.avro:avro-ipc
@@ -328,6 +353,7 @@ org.apache.directory.server:apacheds-i18n
org.apache.directory.server:apacheds-kerberos-codec
org.apache.htrace:htrace-core
org.apache.ivy:ivy
+org.apache.geronimo.specs:geronimo-jcache_1.0_spec
org.apache.mesos:mesos
org.apache.parquet:parquet-column
org.apache.parquet:parquet-common
@@ -369,6 +395,20 @@ org.eclipse.jetty:jetty-webapp
org.eclipse.jetty:jetty-xml
org.scala-lang.modules:scala-xml_2.12
org.opencypher:okapi-shade
+com.github.joshelser:dropwizard-metrics-hadoop-metrics2-reporter
+com.zaxxer:HikariCP
+org.apache.hive:hive-common
+org.apache.hive:hive-llap-common
+org.apache.hive:hive-serde
+org.apache.hive:hive-service-rpc
+org.apache.hive:hive-shims-0.23
+org.apache.hive:hive-shims
+org.apache.hive:hive-shims-scheduler
+org.apache.hive:hive-storage-api
+org.apache.hive:hive-vector-code-gen
+org.datanucleus:javax.jdo
+com.tdunning:json
+org.apache.velocity:velocity
core/src/main/java/org/apache/spark/util/collection/TimSort.java
core/src/main/resources/org/apache/spark/ui/static/bootstrap*
@@ -387,6 +427,7 @@ BSD 2-Clause
------------
com.github.luben:zstd-jni
+dnsjava:dnsjava
javolution:javolution
com.esotericsoftware:kryo-shaded
com.esotericsoftware:minlog
@@ -394,8 +435,11 @@ com.esotericsoftware:reflectasm
com.google.protobuf:protobuf-java
org.codehaus.janino:commons-compiler
org.codehaus.janino:janino
+org.codehaus.woodstox:stax2-api
jline:jline
org.jodd:jodd-core
+com.github.wendykierp:JTransforms
+pl.edu.icm:JLargeArrays
BSD 3-Clause
@@ -408,6 +452,7 @@ org.antlr:stringtemplate
org.antlr:antlr4-runtime
antlr:antlr
com.github.fommil.netlib:core
+com.google.re2j:re2j
com.thoughtworks.paranamer:paranamer
org.scala-lang:scala-compiler
org.scala-lang:scala-library
@@ -433,8 +478,13 @@ is distributed under the 3-Clause BSD license.
MIT License
-----------
-org.spire-math:spire-macros_2.12
-org.spire-math:spire_2.12
+com.microsoft.sqlserver:mssql-jdbc
+org.typelevel:spire_2.12
+org.typelevel:spire-macros_2.12
+org.typelevel:spire-platform_2.12
+org.typelevel:spire-util_2.12
+org.typelevel:algebra_2.12:jar
+org.typelevel:cats-kernel_2.12
org.typelevel:machinist_2.12
net.razorvine:pyrolite
org.slf4j:jcl-over-slf4j
@@ -458,6 +508,7 @@ Common Development and Distribution License (CDDL) 1.0
javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html
javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173
+javax.transaction:javax.transaction-api
Common Development and Distribution License (CDDL) 1.1
@@ -496,11 +547,6 @@ Eclipse Public License (EPL) 2.0
jakarta.annotation:jakarta-annotation-api https://projects.eclipse.org/projects/ee4j.ca
jakarta.ws.rs:jakarta.ws.rs-api https://github.com/eclipse-ee4j/jaxrs-api
-Mozilla Public License (MPL) 1.1
---------------------------------
-
-com.github.rwl:jtransforms https://sourceforge.net/projects/jtransforms/
-
Python Software Foundation License
----------------------------------
diff --git a/NOTICE-binary b/NOTICE-binary
index f93e088a9a731..80ddfd10a1874 100644
--- a/NOTICE-binary
+++ b/NOTICE-binary
@@ -1135,4 +1135,356 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
-limitations under the License.
\ No newline at end of file
+limitations under the License.
+
+dropwizard-metrics-hadoop-metrics2-reporter
+Copyright 2016 Josh Elser
+
+Hive Common
+Copyright 2019 The Apache Software Foundation
+
+Hive Llap Common
+Copyright 2019 The Apache Software Foundation
+
+Hive Serde
+Copyright 2019 The Apache Software Foundation
+
+Hive Service RPC
+Copyright 2019 The Apache Software Foundation
+
+Hive Shims 0.23
+Copyright 2019 The Apache Software Foundation
+
+Hive Shims Common
+Copyright 2019 The Apache Software Foundation
+
+Hive Shims Scheduler
+Copyright 2019 The Apache Software Foundation
+
+Hive Storage API
+Copyright 2018 The Apache Software Foundation
+
+Hive Vector-Code-Gen Utilities
+Copyright 2019 The Apache Software Foundation
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ Copyright 2015-2015 DataNucleus
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+Apache Velocity
+
+Copyright (C) 2000-2007 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+Apache Yetus - Audience Annotations
+Copyright 2015-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+Ehcache V3
+Copyright 2014-2016 Terracotta, Inc.
+
+The product includes software from the Apache Commons Lang project,
+under the Apache License 2.0 (see: org.ehcache.impl.internal.classes.commonslang)
+
+Apache Geronimo JCache Spec 1.0
+Copyright 2003-2014 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby-kerb Admin
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby-kerb Client
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby-kerb Common
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby-kerb core
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby-kerb Crypto
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby-kerb Identity
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby-kerb Server
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerb Simple Kdc
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby-kerb Util
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby ASN1 Project
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby Config
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby PKIX Project
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby Util
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Kerby XDR Project
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+
+Token provider
+Copyright 2014-2017 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
diff --git a/R/check-cran.sh b/R/check-cran.sh
index 22cc9c6b601fc..22c8f423cfd12 100755
--- a/R/check-cran.sh
+++ b/R/check-cran.sh
@@ -65,6 +65,10 @@ fi
echo "Running CRAN check with $CRAN_CHECK_OPTIONS options"
+# Remove this environment variable to allow checking suggested packages once
+# Jenkins installs arrow. See SPARK-29339.
+export _R_CHECK_FORCE_SUGGESTS_=FALSE
+
if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ]
then
"$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz"
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index f4780862099d3..95d3e52bef3a9 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -22,7 +22,8 @@ Suggests:
rmarkdown,
testthat,
e1071,
- survival
+ survival,
+ arrow
Collate:
'schema.R'
'generics.R'
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 43ea27b359a9c..f27ef4ee28f16 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -148,19 +148,7 @@ getDefaultSqlSource <- function() {
}
writeToFileInArrow <- function(fileName, rdf, numPartitions) {
- requireNamespace1 <- requireNamespace
-
- # R API in Arrow is not yet released in CRAN. CRAN requires to add the
- # package in requireNamespace at DESCRIPTION. Later, CRAN checks if the package is available
- # or not. Therefore, it works around by avoiding direct requireNamespace.
- # Currently, as of Arrow 0.12.0, it can be installed by install_github. See ARROW-3204.
- if (requireNamespace1("arrow", quietly = TRUE)) {
- record_batch <- get("record_batch", envir = asNamespace("arrow"), inherits = FALSE)
- RecordBatchStreamWriter <- get(
- "RecordBatchStreamWriter", envir = asNamespace("arrow"), inherits = FALSE)
- FileOutputStream <- get(
- "FileOutputStream", envir = asNamespace("arrow"), inherits = FALSE)
-
+ if (requireNamespace("arrow", quietly = TRUE)) {
numPartitions <- if (!is.null(numPartitions)) {
numToInt(numPartitions)
} else {
@@ -176,11 +164,11 @@ writeToFileInArrow <- function(fileName, rdf, numPartitions) {
stream_writer <- NULL
tryCatch({
for (rdf_slice in rdf_slices) {
- batch <- record_batch(rdf_slice)
+ batch <- arrow::record_batch(rdf_slice)
if (is.null(stream_writer)) {
- stream <- FileOutputStream(fileName)
+ stream <- arrow::FileOutputStream(fileName)
schema <- batch$schema
- stream_writer <- RecordBatchStreamWriter(stream, schema)
+ stream_writer <- arrow::RecordBatchStreamWriter(stream, schema)
}
stream_writer$write_batch(batch)
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 51ae2d2954a9a..93ba1307043a3 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -301,7 +301,7 @@ broadcastRDD <- function(sc, object) {
#' Set the checkpoint directory
#'
#' Set the directory under which RDDs are going to be checkpointed. The
-#' directory must be a HDFS path if running on a cluster.
+#' directory must be an HDFS path if running on a cluster.
#'
#' @param sc Spark Context to use
#' @param dirName Directory path
@@ -446,7 +446,7 @@ setLogLevel <- function(level) {
#' Set checkpoint directory
#'
#' Set the directory under which SparkDataFrame are going to be checkpointed. The directory must be
-#' a HDFS path if running on a cluster.
+#' an HDFS path if running on a cluster.
#'
#' @rdname setCheckpointDir
#' @param directory Directory path to checkpoint to
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index b38d245a0cca7..a6febb1cbd132 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -232,11 +232,7 @@ readMultipleObjectsWithKeys <- function(inputCon) {
}
readDeserializeInArrow <- function(inputCon) {
- # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204.
- requireNamespace1 <- requireNamespace
- if (requireNamespace1("arrow", quietly = TRUE)) {
- RecordBatchStreamReader <- get(
- "RecordBatchStreamReader", envir = asNamespace("arrow"), inherits = FALSE)
+ if (requireNamespace("arrow", quietly = TRUE)) {
# Arrow drops `as_tibble` since 0.14.0, see ARROW-5190.
useAsTibble <- exists("as_tibble", envir = asNamespace("arrow"))
@@ -246,7 +242,7 @@ readDeserializeInArrow <- function(inputCon) {
# for now.
dataLen <- readInt(inputCon)
arrowData <- readBin(inputCon, raw(), as.integer(dataLen), endian = "big")
- batches <- RecordBatchStreamReader(arrowData)$batches()
+ batches <- arrow::RecordBatchStreamReader(arrowData)$batches()
if (useAsTibble) {
as_tibble <- get("as_tibble", envir = asNamespace("arrow"))
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index eecb84572a30b..eec221c2be4bf 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -3617,7 +3617,7 @@ setMethod("size",
#' @details
#' \code{slice}: Returns an array containing all the elements in x from the index start
-#' (or starting from the end if start is negative) with the specified length.
+#' (array indices start at 1, or from the end if start is negative) with the specified length.
#'
#' @rdname column_collection_functions
#' @param start an index indicating the first element occurring in the result.
diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R
index 9a77b07462585..d238ff93ed245 100644
--- a/R/pkg/R/mllib_recommendation.R
+++ b/R/pkg/R/mllib_recommendation.R
@@ -82,6 +82,12 @@ setClass("ALSModel", representation(jobj = "jobj"))
#' statsS <- summary(modelS)
#' }
#' @note spark.als since 2.1.0
+#' @note The input rating dataframe to the ALS implementation should be deterministic.
+#' Nondeterministic data can cause failures when fitting the ALS model. For example,
+#' an order-sensitive operation like sampling after a repartition makes the dataframe
+#' output nondeterministic, as in \code{sample(repartition(df, 2L), FALSE, 0.5, 1618L)}.
+#' Checkpointing the sampled dataframe or adding a sort before sampling can help make
+#' the dataframe deterministic.
setMethod("spark.als", signature(data = "SparkDataFrame"),
function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE,
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 0d6f32c8f7e1f..cb3c1c59d12ed 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -222,15 +222,11 @@ writeArgs <- function(con, args) {
}
writeSerializeInArrow <- function(conn, df) {
- # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204.
- requireNamespace1 <- requireNamespace
- if (requireNamespace1("arrow", quietly = TRUE)) {
- write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE)
-
+ if (requireNamespace("arrow", quietly = TRUE)) {
# There looks no way to send each batch in streaming format via socket
# connection. See ARROW-4512.
# So, it writes the whole Arrow streaming-formatted binary at once for now.
- writeRaw(conn, write_arrow(df, raw()))
+ writeRaw(conn, arrow::write_arrow(df, raw()))
} else {
stop("'arrow' package should be installed.")
}
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 31b986c326d0c..cdb59093781fb 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -266,11 +266,12 @@ sparkR.sparkContext <- function(
#' df <- read.json(path)
#'
#' sparkR.session("local[2]", "SparkR", "/home/spark")
-#' sparkR.session("yarn-client", "SparkR", "/home/spark",
-#' list(spark.executor.memory="4g"),
+#' sparkR.session("yarn", "SparkR", "/home/spark",
+#' list(spark.executor.memory="4g", spark.submit.deployMode="client"),
#' c("one.jar", "two.jar", "three.jar"),
#' c("com.databricks:spark-avro_2.12:2.0.1"))
-#' sparkR.session(spark.master = "yarn-client", spark.executor.memory = "4g")
+#' sparkR.session(spark.master = "yarn", spark.submit.deployMode = "client",
+#'                spark.executor.memory = "4g")
#'}
#' @note sparkR.session since 2.0.0
sparkR.session <- function(
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 80dc4ee634512..dfe69b7f4f1fb 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -50,7 +50,7 @@ compute <- function(mode, partition, serializer, deserializer, key,
} else {
# Check to see if inputData is a valid data.frame
stopifnot(deserializer == "byte" || deserializer == "arrow")
- stopifnot(class(inputData) == "data.frame")
+ stopifnot(is.data.frame(inputData))
}
if (mode == 2) {
diff --git a/R/pkg/tests/fulltests/test_sparkR.R b/R/pkg/tests/fulltests/test_sparkR.R
index f73fc6baeccef..4232f5ec430f6 100644
--- a/R/pkg/tests/fulltests/test_sparkR.R
+++ b/R/pkg/tests/fulltests/test_sparkR.R
@@ -36,8 +36,8 @@ test_that("sparkCheckInstall", {
# "yarn-client, mesos-client" mode, SPARK_HOME was not set
sparkHome <- ""
- master <- "yarn-client"
- deployMode <- ""
+ master <- "yarn"
+ deployMode <- "client"
expect_error(sparkCheckInstall(sparkHome, master, deployMode))
sparkHome <- ""
master <- ""
diff --git a/R/pkg/tests/fulltests/test_sparkSQL_arrow.R b/R/pkg/tests/fulltests/test_sparkSQL_arrow.R
index 825c7423e1579..97972753a78fa 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL_arrow.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL_arrow.R
@@ -101,7 +101,7 @@ test_that("dapply() Arrow optimization", {
tryCatch({
ret <- dapply(df,
function(rdf) {
- stopifnot(class(rdf) == "data.frame")
+ stopifnot(is.data.frame(rdf))
rdf
},
schema(df))
@@ -115,7 +115,7 @@ test_that("dapply() Arrow optimization", {
tryCatch({
ret <- dapply(df,
function(rdf) {
- stopifnot(class(rdf) == "data.frame")
+ stopifnot(is.data.frame(rdf))
# mtcars' hp is more then 50.
stopifnot(all(rdf$hp > 50))
rdf
@@ -199,7 +199,7 @@ test_that("gapply() Arrow optimization", {
if (length(key) > 0) {
stopifnot(is.numeric(key[[1]]))
}
- stopifnot(class(grouped) == "data.frame")
+ stopifnot(is.data.frame(grouped))
grouped
},
schema(df))
@@ -217,7 +217,7 @@ test_that("gapply() Arrow optimization", {
if (length(key) > 0) {
stopifnot(is.numeric(key[[1]]))
}
- stopifnot(class(grouped) == "data.frame")
+ stopifnot(is.data.frame(grouped))
stopifnot(length(colnames(grouped)) == 11)
# mtcars' hp is more then 50.
stopifnot(all(grouped$hp > 50))
diff --git a/appveyor.yml b/appveyor.yml
index a61436c5d2e68..00c688ba18eb6 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -42,13 +42,13 @@ install:
# Install maven and dependencies
- ps: .\dev\appveyor-install-dependencies.ps1
# Required package for R unit tests
- - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'e1071', 'survival'), repos='https://cloud.r-project.org/')"
+ - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'e1071', 'survival', 'arrow'), repos='https://cloud.r-project.org/')"
# Here, we use the fixed version of testthat. For more details, please see SPARK-22817.
# As of devtools 2.1.0, it requires testthat higher then 2.1.1 as a dependency. SparkR test requires testthat 1.0.2.
# Therefore, we don't use devtools but installs it directly from the archive including its dependencies.
- cmd: R -e "install.packages(c('crayon', 'praise', 'R6'), repos='https://cloud.r-project.org/')"
- cmd: R -e "install.packages('https://cloud.r-project.org/src/contrib/Archive/testthat/testthat_1.0.2.tar.gz', repos=NULL, type='source')"
- - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')"
+ - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival'); packageVersion('arrow')"
build_script:
# '-Djna.nosys=true' is required to avoid kernel32.dll load failure.
diff --git a/build/mvn b/build/mvn
index f68377b3ddc71..3628be9880253 100755
--- a/build/mvn
+++ b/build/mvn
@@ -22,7 +22,7 @@ _DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# Preserve the calling directory
_CALLING_DIR="$(pwd)"
# Options used during compilation
-_COMPILE_JVM_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m"
+_COMPILE_JVM_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=1g"
# Installs any application tarball given a URL, the expected tarball name,
# and, optionally, a checkable binary path to determine if the binary has
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java
index 6af45aec3c7b2..b33c53871c32f 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java
@@ -252,7 +252,7 @@ private static <T> Predicate<? super T> getPredicate(
return (value) -> set.contains(indexValueForEntity(getter, value));
} else {
- HashSet set = new HashSet<>(values.size());
+      HashSet<Comparable<Object>> set = new HashSet<>(values.size());
for (Object key : values) {
set.add(asKey(key));
}
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java
index b8c5fab8709ed..d2a26982d8703 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java
@@ -124,7 +124,7 @@ interface Accessor {
Object get(Object instance) throws ReflectiveOperationException;
- Class getType();
+    Class<?> getType();
}
private class FieldAccessor implements Accessor {
@@ -141,7 +141,7 @@ public Object get(Object instance) throws ReflectiveOperationException {
}
@Override
- public Class getType() {
+    public Class<?> getType() {
return field.getType();
}
}
@@ -160,7 +160,7 @@ public Object get(Object instance) throws ReflectiveOperationException {
}
@Override
- public Class getType() {
+    public Class<?> getType() {
return method.getReturnType();
}
}
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index c107af9ceb415..2ee17800c10e4 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -35,6 +35,12 @@
+
+
+ org.scala-lang
+ scala-library
+
+
io.netty
@@ -87,13 +93,6 @@
-
-
- org.scala-lang
- scala-library
- ${scala.version}
- test
-
log4j
log4j
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 53835d8304866..c9ef9f918ffd1 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -293,9 +293,8 @@ public void close() {
}
connectionPool.clear();
- if (workerGroup != null) {
+ if (workerGroup != null && !workerGroup.isShuttingDown()) {
workerGroup.shutdownGracefully();
- workerGroup = null;
}
}
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
index 736059fdd1f57..490915f6de4b3 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -112,4 +112,27 @@ public static int[] decode(ByteBuf buf) {
return ints;
}
}
+
+ /** Long integer arrays are encoded with their length followed by long integers. */
+ public static class LongArrays {
+ public static int encodedLength(long[] longs) {
+ return 4 + 8 * longs.length;
+ }
+
+ public static void encode(ByteBuf buf, long[] longs) {
+ buf.writeInt(longs.length);
+ for (long i : longs) {
+ buf.writeLong(i);
+ }
+ }
+
+ public static long[] decode(ByteBuf buf) {
+ int numLongs = buf.readInt();
+ long[] longs = new long[numLongs];
+ for (int i = 0; i < longs.length; i ++) {
+ longs[i] = buf.readLong();
+ }
+ return longs;
+ }
+ }
}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 2aec4a33bbe43..9b76981c31c57 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -217,4 +217,11 @@ public Iterable<Map.Entry<String, String>> getAll() {
assertFalse(c1.isActive());
}
}
+
+ @Test(expected = IOException.class)
+ public void closeFactoryBeforeCreateClient() throws IOException, InterruptedException {
+ TransportClientFactory factory = context.createClientFactory();
+ factory.close();
+ factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ }
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
index 037e5cf7e5222..2d7a72315cf23 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
@@ -106,7 +106,7 @@ protected void handleMessage(
numBlockIds += ids.length;
}
streamId = streamManager.registerStream(client.getClientId(),
- new ManagedBufferIterator(msg, numBlockIds), client.getChannel());
+ new ShuffleManagedBufferIterator(msg), client.getChannel());
} else {
// For the compatibility with the old version, still keep the support for OpenBlocks.
OpenBlocks msg = (OpenBlocks) msgObj;
@@ -299,21 +299,6 @@ private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
return mapIdAndReduceIds;
}
- ManagedBufferIterator(FetchShuffleBlocks msg, int numBlockIds) {
- final int[] mapIdAndReduceIds = new int[2 * numBlockIds];
- int idx = 0;
- for (int i = 0; i < msg.mapIds.length; i++) {
- for (int reduceId : msg.reduceIds[i]) {
- mapIdAndReduceIds[idx++] = msg.mapIds[i];
- mapIdAndReduceIds[idx++] = reduceId;
- }
- }
- assert(idx == 2 * numBlockIds);
- size = mapIdAndReduceIds.length;
- blockDataForIndexFn = index -> blockManager.getBlockData(msg.appId, msg.execId,
- msg.shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
- }
-
@Override
public boolean hasNext() {
return index < size;
@@ -328,6 +313,49 @@ public ManagedBuffer next() {
}
}
+  private class ShuffleManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+ private int mapIdx = 0;
+ private int reduceIdx = 0;
+
+ private final String appId;
+ private final String execId;
+ private final int shuffleId;
+ private final long[] mapIds;
+ private final int[][] reduceIds;
+
+ ShuffleManagedBufferIterator(FetchShuffleBlocks msg) {
+ appId = msg.appId;
+ execId = msg.execId;
+ shuffleId = msg.shuffleId;
+ mapIds = msg.mapIds;
+ reduceIds = msg.reduceIds;
+ }
+
+ @Override
+ public boolean hasNext() {
+      // mapIds.length must equal reduceIds.length, and the passed-in FetchShuffleBlocks
+      // must have non-empty mapIds and reduceIds; see the checking logic in
+      // OneForOneBlockFetcher.
+ assert(mapIds.length != 0 && mapIds.length == reduceIds.length);
+ return mapIdx < mapIds.length && reduceIdx < reduceIds[mapIdx].length;
+ }
+
+ @Override
+ public ManagedBuffer next() {
+ final ManagedBuffer block = blockManager.getBlockData(
+ appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]);
+ if (reduceIdx < reduceIds[mapIdx].length - 1) {
+ reduceIdx += 1;
+ } else {
+ reduceIdx = 0;
+ mapIdx += 1;
+ }
+ metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
+ return block;
+ }
+ }
+
@Override
public void channelActive(TransportClient client) {
metrics.activeConnections.inc();
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 50f16fc700f12..8b0d1e145a813 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -172,7 +172,7 @@ public ManagedBuffer getBlockData(
String appId,
String execId,
int shuffleId,
- int mapId,
+ long mapId,
int reduceId) {
ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
if (executor == null) {
@@ -296,7 +296,7 @@ private void deleteNonShuffleServiceServedFiles(String[] dirs) {
* and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId.
*/
private ManagedBuffer getSortBasedShuffleBlockData(
- ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) {
+ ExecutorShuffleInfo executor, int shuffleId, long mapId, int reduceId) {
File indexFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir,
"shuffle_" + shuffleId + "_" + mapId + "_0.index");
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index cc11e92067375..52854c86be3e6 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -24,6 +24,8 @@
import java.util.HashMap;
import com.google.common.primitives.Ints;
+import com.google.common.primitives.Longs;
+import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -111,21 +113,21 @@ private boolean isShuffleBlocks(String[] blockIds) {
*/
private FetchShuffleBlocks createFetchShuffleBlocksMsg(
String appId, String execId, String[] blockIds) {
- int shuffleId = splitBlockId(blockIds[0])[0];
-    HashMap<Integer, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
+ int shuffleId = splitBlockId(blockIds[0]).left;
+    HashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
for (String blockId : blockIds) {
- int[] blockIdParts = splitBlockId(blockId);
- if (blockIdParts[0] != shuffleId) {
+      ImmutableTriple<Integer, Long, Integer> blockIdParts = splitBlockId(blockId);
+ if (blockIdParts.left != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
", got:" + blockId);
}
- int mapId = blockIdParts[1];
+ long mapId = blockIdParts.middle;
if (!mapIdToReduceIds.containsKey(mapId)) {
mapIdToReduceIds.put(mapId, new ArrayList<>());
}
- mapIdToReduceIds.get(mapId).add(blockIdParts[2]);
+ mapIdToReduceIds.get(mapId).add(blockIdParts.right);
}
- int[] mapIds = Ints.toArray(mapIdToReduceIds.keySet());
+ long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet());
int[][] reduceIdArr = new int[mapIds.length][];
for (int i = 0; i < mapIds.length; i++) {
reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
@@ -134,17 +136,16 @@ private FetchShuffleBlocks createFetchShuffleBlocksMsg(
}
/** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */
- private int[] splitBlockId(String blockId) {
+  private ImmutableTriple<Integer, Long, Integer> splitBlockId(String blockId) {
String[] blockIdParts = blockId.split("_");
if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
throw new IllegalArgumentException(
"Unexpected shuffle block id format: " + blockId);
}
- return new int[] {
- Integer.parseInt(blockIdParts[1]),
- Integer.parseInt(blockIdParts[2]),
- Integer.parseInt(blockIdParts[3])
- };
+ return new ImmutableTriple<>(
+ Integer.parseInt(blockIdParts[1]),
+ Long.parseLong(blockIdParts[2]),
+ Integer.parseInt(blockIdParts[3]));
}
/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java
index 466eeb3e048a8..faa960d414bcc 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java
@@ -34,14 +34,14 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
public final int shuffleId;
// The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds,
// it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id.
- public final int[] mapIds;
+ public final long[] mapIds;
public final int[][] reduceIds;
public FetchShuffleBlocks(
String appId,
String execId,
int shuffleId,
- int[] mapIds,
+ long[] mapIds,
int[][] reduceIds) {
this.appId = appId;
this.execId = execId;
@@ -98,7 +98,7 @@ public int encodedLength() {
return Encoders.Strings.encodedLength(appId)
+ Encoders.Strings.encodedLength(execId)
+ 4 /* encoded length of shuffleId */
- + Encoders.IntArrays.encodedLength(mapIds)
+ + Encoders.LongArrays.encodedLength(mapIds)
+ 4 /* encoded length of reduceIds.size() */
+ encodedLengthOfReduceIds;
}
@@ -108,7 +108,7 @@ public void encode(ByteBuf buf) {
Encoders.Strings.encode(buf, appId);
Encoders.Strings.encode(buf, execId);
buf.writeInt(shuffleId);
- Encoders.IntArrays.encode(buf, mapIds);
+ Encoders.LongArrays.encode(buf, mapIds);
buf.writeInt(reduceIds.length);
for (int[] ids: reduceIds) {
Encoders.IntArrays.encode(buf, ids);
@@ -119,7 +119,7 @@ public static FetchShuffleBlocks decode(ByteBuf buf) {
String appId = Encoders.Strings.decode(buf);
String execId = Encoders.Strings.decode(buf);
int shuffleId = buf.readInt();
- int[] mapIds = Encoders.IntArrays.decode(buf);
+ long[] mapIds = Encoders.LongArrays.decode(buf);
int reduceIdsSize = buf.readInt();
int[][] reduceIds = new int[reduceIdsSize][];
for (int i = 0; i < reduceIdsSize; i++) {
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
index 649c471dc1679..ba40f4a45ac8f 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -29,7 +29,7 @@ public class BlockTransferMessagesSuite {
public void serializeOpenShuffleBlocks() {
checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }));
checkSerializeDeserialize(new FetchShuffleBlocks(
- "app-1", "exec-2", 0, new int[] {0, 1},
+ "app-1", "exec-2", 0, new long[] {0, 1},
new int[][] {{ 0, 1 }, { 0, 1, 2 }}));
checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")));
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
index 9c623a70424b6..6a5d04b6f417b 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
@@ -101,7 +101,7 @@ public void testFetchShuffleBlocks() {
when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]);
FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks(
- "app0", "exec1", 0, new int[] { 0 }, new int[][] {{ 0, 1 }});
+ "app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }});
checkOpenBlocksReceive(fetchShuffleBlocks, blockMarkers);
verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index 66633cc7a3595..26a11672b8068 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -64,7 +64,7 @@ public void testFetchOne() {
BlockFetchingListener listener = fetchBlocks(
blocks,
blockIds,
- new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0 }}),
+ new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0 }}),
conf);
verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
@@ -100,7 +100,7 @@ public void testFetchThreeShuffleBlocks() {
BlockFetchingListener listener = fetchBlocks(
blocks,
blockIds,
- new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0, 1, 2 }}),
+ new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 1, 2 }}),
conf);
for (int i = 0; i < 3; i ++) {
diff --git a/common/tags/src/test/java/org/apache/spark/tags/ExtendedSQLTest.java b/common/tags/src/test/java/org/apache/spark/tags/ExtendedSQLTest.java
new file mode 100644
index 0000000000000..1c0fff1b4045d
--- /dev/null
+++ b/common/tags/src/test/java/org/apache/spark/tags/ExtendedSQLTest.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.tags;
+
+import org.scalatest.TagAnnotation;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+@TagAnnotation
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.METHOD, ElementType.TYPE})
+public @interface ExtendedSQLTest { }
diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
index fdb81a06d41c9..72aa682bb95bc 100644
--- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
+++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.unsafe.types
import org.apache.commons.text.similarity.LevenshteinDistance
import org.scalacheck.{Arbitrary, Gen}
-import org.scalatest.prop.GeneratorDrivenPropertyChecks
+import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks
// scalastyle:off
import org.scalatest.{FunSuite, Matchers}
@@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8}
/**
* This TestSuite utilize ScalaCheck to generate randomized inputs for UTF8String testing.
*/
-class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenPropertyChecks with Matchers {
+class UTF8StringPropertyCheckSuite extends FunSuite with ScalaCheckDrivenPropertyChecks with Matchers {
// scalastyle:on
test("toString") {
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index da0b06d295252..f52d33fd64223 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -113,6 +113,15 @@
# /metrics/applications/json # App information
# /metrics/master/json # Master information
+# org.apache.spark.metrics.sink.PrometheusServlet
+# Name: Default: Description:
+# path VARIES* Path prefix from the web server root
+#
+# * Default path is /metrics/prometheus for all instances except the master. The
+# master has two paths:
+# /metrics/applications/prometheus # App information
+# /metrics/master/prometheus # Master information
+
# org.apache.spark.metrics.sink.GraphiteSink
# Name: Default: Description:
# host NONE Hostname of the Graphite server, must be set
@@ -192,4 +201,10 @@
#driver.source.jvm.class=org.apache.spark.metrics.source.JvmSource
-#executor.source.jvm.class=org.apache.spark.metrics.source.JvmSource
\ No newline at end of file
+#executor.source.jvm.class=org.apache.spark.metrics.source.JvmSource
+
+# Example configuration for PrometheusServlet
+#*.sink.prometheusServlet.class=org.apache.spark.metrics.sink.PrometheusServlet
+#*.sink.prometheusServlet.path=/metrics/prometheus
+#master.sink.prometheusServlet.path=/metrics/master/prometheus
+#applications.sink.prometheusServlet.path=/metrics/applications/prometheus
diff --git a/core/benchmarks/CoalescedRDDBenchmark-jdk11-results.txt b/core/benchmarks/CoalescedRDDBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..e944111ff9e93
--- /dev/null
+++ b/core/benchmarks/CoalescedRDDBenchmark-jdk11-results.txt
@@ -0,0 +1,40 @@
+================================================================================================
+Coalesced RDD , large scale
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Coalesced RDD: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Coalesce Num Partitions: 100 Num Hosts: 1 344 360 14 0.3 3441.4 1.0X
+Coalesce Num Partitions: 100 Num Hosts: 5 283 301 22 0.4 2825.1 1.2X
+Coalesce Num Partitions: 100 Num Hosts: 10 270 271 2 0.4 2700.5 1.3X
+Coalesce Num Partitions: 100 Num Hosts: 20 272 273 1 0.4 2721.1 1.3X
+Coalesce Num Partitions: 100 Num Hosts: 40 271 272 1 0.4 2710.0 1.3X
+Coalesce Num Partitions: 100 Num Hosts: 80 266 267 2 0.4 2656.3 1.3X
+Coalesce Num Partitions: 500 Num Hosts: 1 609 619 15 0.2 6089.0 0.6X
+Coalesce Num Partitions: 500 Num Hosts: 5 338 343 6 0.3 3383.0 1.0X
+Coalesce Num Partitions: 500 Num Hosts: 10 303 306 3 0.3 3029.4 1.1X
+Coalesce Num Partitions: 500 Num Hosts: 20 286 288 2 0.4 2855.9 1.2X
+Coalesce Num Partitions: 500 Num Hosts: 40 279 282 4 0.4 2793.3 1.2X
+Coalesce Num Partitions: 500 Num Hosts: 80 273 275 3 0.4 2725.9 1.3X
+Coalesce Num Partitions: 1000 Num Hosts: 1 951 955 4 0.1 9514.1 0.4X
+Coalesce Num Partitions: 1000 Num Hosts: 5 421 429 8 0.2 4211.3 0.8X
+Coalesce Num Partitions: 1000 Num Hosts: 10 347 352 4 0.3 3473.5 1.0X
+Coalesce Num Partitions: 1000 Num Hosts: 20 309 312 5 0.3 3087.5 1.1X
+Coalesce Num Partitions: 1000 Num Hosts: 40 290 294 6 0.3 2896.4 1.2X
+Coalesce Num Partitions: 1000 Num Hosts: 80 281 286 5 0.4 2811.3 1.2X
+Coalesce Num Partitions: 5000 Num Hosts: 1 3928 3950 27 0.0 39278.0 0.1X
+Coalesce Num Partitions: 5000 Num Hosts: 5 1373 1389 27 0.1 13725.2 0.3X
+Coalesce Num Partitions: 5000 Num Hosts: 10 812 827 13 0.1 8123.3 0.4X
+Coalesce Num Partitions: 5000 Num Hosts: 20 530 540 9 0.2 5299.1 0.6X
+Coalesce Num Partitions: 5000 Num Hosts: 40 421 425 5 0.2 4210.5 0.8X
+Coalesce Num Partitions: 5000 Num Hosts: 80 335 344 12 0.3 3353.7 1.0X
+Coalesce Num Partitions: 10000 Num Hosts: 1 7116 7120 4 0.0 71159.0 0.0X
+Coalesce Num Partitions: 10000 Num Hosts: 5 2539 2598 51 0.0 25390.1 0.1X
+Coalesce Num Partitions: 10000 Num Hosts: 10 1393 1432 34 0.1 13928.1 0.2X
+Coalesce Num Partitions: 10000 Num Hosts: 20 833 1009 303 0.1 8329.2 0.4X
+Coalesce Num Partitions: 10000 Num Hosts: 40 562 563 3 0.2 5615.2 0.6X
+Coalesce Num Partitions: 10000 Num Hosts: 80 420 426 7 0.2 4204.0 0.8X
+
+
diff --git a/core/benchmarks/CoalescedRDDBenchmark-results.txt b/core/benchmarks/CoalescedRDDBenchmark-results.txt
index dd63b0adea4f2..f1b867951a074 100644
--- a/core/benchmarks/CoalescedRDDBenchmark-results.txt
+++ b/core/benchmarks/CoalescedRDDBenchmark-results.txt
@@ -2,39 +2,39 @@
Coalesced RDD , large scale
================================================================================================
-Java HotSpot(TM) 64-Bit Server VM 1.8.0_201-b09 on Windows 10 10.0
-Intel64 Family 6 Model 63 Stepping 2, GenuineIntel
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Coalesced RDD: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Coalesce Num Partitions: 100 Num Hosts: 1 346 364 24 0.3 3458.9 1.0X
-Coalesce Num Partitions: 100 Num Hosts: 5 258 264 6 0.4 2579.0 1.3X
-Coalesce Num Partitions: 100 Num Hosts: 10 242 249 7 0.4 2415.2 1.4X
-Coalesce Num Partitions: 100 Num Hosts: 20 237 242 7 0.4 2371.7 1.5X
-Coalesce Num Partitions: 100 Num Hosts: 40 230 231 1 0.4 2299.8 1.5X
-Coalesce Num Partitions: 100 Num Hosts: 80 222 233 14 0.4 2223.0 1.6X
-Coalesce Num Partitions: 500 Num Hosts: 1 659 665 5 0.2 6590.4 0.5X
-Coalesce Num Partitions: 500 Num Hosts: 5 340 381 47 0.3 3395.2 1.0X
-Coalesce Num Partitions: 500 Num Hosts: 10 279 307 47 0.4 2788.3 1.2X
-Coalesce Num Partitions: 500 Num Hosts: 20 259 261 2 0.4 2591.9 1.3X
-Coalesce Num Partitions: 500 Num Hosts: 40 241 250 15 0.4 2406.5 1.4X
-Coalesce Num Partitions: 500 Num Hosts: 80 235 237 3 0.4 2349.9 1.5X
-Coalesce Num Partitions: 1000 Num Hosts: 1 1050 1053 4 0.1 10503.2 0.3X
-Coalesce Num Partitions: 1000 Num Hosts: 5 405 407 2 0.2 4049.5 0.9X
-Coalesce Num Partitions: 1000 Num Hosts: 10 320 322 2 0.3 3202.7 1.1X
-Coalesce Num Partitions: 1000 Num Hosts: 20 276 277 0 0.4 2762.3 1.3X
-Coalesce Num Partitions: 1000 Num Hosts: 40 257 260 5 0.4 2571.2 1.3X
-Coalesce Num Partitions: 1000 Num Hosts: 80 245 252 13 0.4 2448.9 1.4X
-Coalesce Num Partitions: 5000 Num Hosts: 1 3099 3145 55 0.0 30988.6 0.1X
-Coalesce Num Partitions: 5000 Num Hosts: 5 1037 1050 20 0.1 10374.4 0.3X
-Coalesce Num Partitions: 5000 Num Hosts: 10 626 633 8 0.2 6261.8 0.6X
-Coalesce Num Partitions: 5000 Num Hosts: 20 426 431 5 0.2 4258.6 0.8X
-Coalesce Num Partitions: 5000 Num Hosts: 40 328 341 22 0.3 3275.4 1.1X
-Coalesce Num Partitions: 5000 Num Hosts: 80 272 275 4 0.4 2721.4 1.3X
-Coalesce Num Partitions: 10000 Num Hosts: 1 5516 5526 9 0.0 55156.8 0.1X
-Coalesce Num Partitions: 10000 Num Hosts: 5 1956 1992 48 0.1 19560.9 0.2X
-Coalesce Num Partitions: 10000 Num Hosts: 10 1045 1057 18 0.1 10447.4 0.3X
-Coalesce Num Partitions: 10000 Num Hosts: 20 637 658 24 0.2 6373.2 0.5X
-Coalesce Num Partitions: 10000 Num Hosts: 40 431 448 15 0.2 4312.9 0.8X
-Coalesce Num Partitions: 10000 Num Hosts: 80 326 328 2 0.3 3263.4 1.1X
+Coalesce Num Partitions: 100 Num Hosts: 1 395 401 9 0.3 3952.3 1.0X
+Coalesce Num Partitions: 100 Num Hosts: 5 296 344 42 0.3 2963.2 1.3X
+Coalesce Num Partitions: 100 Num Hosts: 10 294 308 15 0.3 2941.7 1.3X
+Coalesce Num Partitions: 100 Num Hosts: 20 316 328 13 0.3 3155.2 1.3X
+Coalesce Num Partitions: 100 Num Hosts: 40 294 316 36 0.3 2940.3 1.3X
+Coalesce Num Partitions: 100 Num Hosts: 80 292 324 30 0.3 2922.2 1.4X
+Coalesce Num Partitions: 500 Num Hosts: 1 629 687 61 0.2 6292.4 0.6X
+Coalesce Num Partitions: 500 Num Hosts: 5 354 378 42 0.3 3541.7 1.1X
+Coalesce Num Partitions: 500 Num Hosts: 10 318 338 29 0.3 3179.8 1.2X
+Coalesce Num Partitions: 500 Num Hosts: 20 306 317 11 0.3 3059.2 1.3X
+Coalesce Num Partitions: 500 Num Hosts: 40 294 311 28 0.3 2941.6 1.3X
+Coalesce Num Partitions: 500 Num Hosts: 80 288 309 34 0.3 2883.9 1.4X
+Coalesce Num Partitions: 1000 Num Hosts: 1 956 978 20 0.1 9562.2 0.4X
+Coalesce Num Partitions: 1000 Num Hosts: 5 431 452 36 0.2 4306.2 0.9X
+Coalesce Num Partitions: 1000 Num Hosts: 10 358 379 23 0.3 3581.1 1.1X
+Coalesce Num Partitions: 1000 Num Hosts: 20 324 347 20 0.3 3236.7 1.2X
+Coalesce Num Partitions: 1000 Num Hosts: 40 312 333 20 0.3 3116.8 1.3X
+Coalesce Num Partitions: 1000 Num Hosts: 80 307 342 32 0.3 3068.4 1.3X
+Coalesce Num Partitions: 5000 Num Hosts: 1 3895 3906 12 0.0 38946.8 0.1X
+Coalesce Num Partitions: 5000 Num Hosts: 5 1388 1401 19 0.1 13881.7 0.3X
+Coalesce Num Partitions: 5000 Num Hosts: 10 806 839 57 0.1 8063.7 0.5X
+Coalesce Num Partitions: 5000 Num Hosts: 20 546 573 44 0.2 5462.6 0.7X
+Coalesce Num Partitions: 5000 Num Hosts: 40 413 418 5 0.2 4134.7 1.0X
+Coalesce Num Partitions: 5000 Num Hosts: 80 345 365 23 0.3 3448.1 1.1X
+Coalesce Num Partitions: 10000 Num Hosts: 1 6933 6966 55 0.0 69328.8 0.1X
+Coalesce Num Partitions: 10000 Num Hosts: 5 2455 2499 69 0.0 24551.7 0.2X
+Coalesce Num Partitions: 10000 Num Hosts: 10 1352 1392 34 0.1 13520.2 0.3X
+Coalesce Num Partitions: 10000 Num Hosts: 20 815 853 50 0.1 8147.5 0.5X
+Coalesce Num Partitions: 10000 Num Hosts: 40 558 581 28 0.2 5578.0 0.7X
+Coalesce Num Partitions: 10000 Num Hosts: 80 416 423 5 0.2 4163.3 0.9X
diff --git a/core/benchmarks/KryoBenchmark-jdk11-results.txt b/core/benchmarks/KryoBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..27f0b8f59f47a
--- /dev/null
+++ b/core/benchmarks/KryoBenchmark-jdk11-results.txt
@@ -0,0 +1,28 @@
+================================================================================================
+Benchmark Kryo Unsafe vs safe Serialization
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Benchmark Kryo Unsafe vs safe Serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+basicTypes: Int with unsafe:true 275 288 14 3.6 275.2 1.0X
+basicTypes: Long with unsafe:true 331 336 13 3.0 330.9 0.8X
+basicTypes: Float with unsafe:true 304 305 1 3.3 304.4 0.9X
+basicTypes: Double with unsafe:true 328 332 3 3.0 328.1 0.8X
+Array: Int with unsafe:true 4 4 0 252.8 4.0 69.6X
+Array: Long with unsafe:true 6 6 0 161.5 6.2 44.5X
+Array: Float with unsafe:true 4 4 0 264.6 3.8 72.8X
+Array: Double with unsafe:true 6 7 0 160.5 6.2 44.2X
+Map of string->Double with unsafe:true 52 52 0 19.3 51.8 5.3X
+basicTypes: Int with unsafe:false 344 345 1 2.9 344.3 0.8X
+basicTypes: Long with unsafe:false 372 373 1 2.7 372.3 0.7X
+basicTypes: Float with unsafe:false 333 334 1 3.0 333.4 0.8X
+basicTypes: Double with unsafe:false 344 345 0 2.9 344.3 0.8X
+Array: Int with unsafe:false 25 25 0 40.8 24.5 11.2X
+Array: Long with unsafe:false 37 37 1 27.3 36.7 7.5X
+Array: Float with unsafe:false 11 11 0 92.1 10.9 25.4X
+Array: Double with unsafe:false 17 18 0 58.3 17.2 16.0X
+Map of string->Double with unsafe:false 51 52 1 19.4 51.5 5.3X
+
+
diff --git a/core/benchmarks/KryoBenchmark-results.txt b/core/benchmarks/KryoBenchmark-results.txt
index 91e22f3afc14f..49791e6e87e3a 100644
--- a/core/benchmarks/KryoBenchmark-results.txt
+++ b/core/benchmarks/KryoBenchmark-results.txt
@@ -2,28 +2,27 @@
Benchmark Kryo Unsafe vs safe Serialization
================================================================================================
-Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.13.6
-Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
-
-Benchmark Kryo Unsafe vs safe Serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-basicTypes: Int with unsafe:true 138 / 149 7.2 138.0 1.0X
-basicTypes: Long with unsafe:true 168 / 173 6.0 167.7 0.8X
-basicTypes: Float with unsafe:true 153 / 174 6.5 153.1 0.9X
-basicTypes: Double with unsafe:true 161 / 185 6.2 161.1 0.9X
-Array: Int with unsafe:true 2 / 3 409.7 2.4 56.5X
-Array: Long with unsafe:true 4 / 5 232.5 4.3 32.1X
-Array: Float with unsafe:true 3 / 4 367.3 2.7 50.7X
-Array: Double with unsafe:true 4 / 5 228.5 4.4 31.5X
-Map of string->Double with unsafe:true 38 / 45 26.5 37.8 3.7X
-basicTypes: Int with unsafe:false 176 / 187 5.7 175.9 0.8X
-basicTypes: Long with unsafe:false 191 / 203 5.2 191.2 0.7X
-basicTypes: Float with unsafe:false 166 / 176 6.0 166.2 0.8X
-basicTypes: Double with unsafe:false 174 / 190 5.7 174.3 0.8X
-Array: Int with unsafe:false 19 / 26 52.9 18.9 7.3X
-Array: Long with unsafe:false 27 / 31 37.7 26.5 5.2X
-Array: Float with unsafe:false 8 / 10 124.3 8.0 17.2X
-Array: Double with unsafe:false 12 / 13 83.6 12.0 11.5X
-Map of string->Double with unsafe:false 38 / 42 26.1 38.3 3.6X
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Benchmark Kryo Unsafe vs safe Serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+basicTypes: Int with unsafe:true 269 290 23 3.7 269.0 1.0X
+basicTypes: Long with unsafe:true 294 295 1 3.4 293.8 0.9X
+basicTypes: Float with unsafe:true 300 301 1 3.3 300.4 0.9X
+basicTypes: Double with unsafe:true 304 305 1 3.3 304.0 0.9X
+Array: Int with unsafe:true 5 6 1 193.5 5.2 52.0X
+Array: Long with unsafe:true 8 9 1 131.2 7.6 35.3X
+Array: Float with unsafe:true 6 6 0 163.5 6.1 44.0X
+Array: Double with unsafe:true 9 10 0 108.8 9.2 29.3X
+Map of string->Double with unsafe:true 54 54 1 18.7 53.6 5.0X
+basicTypes: Int with unsafe:false 326 327 1 3.1 326.2 0.8X
+basicTypes: Long with unsafe:false 353 354 1 2.8 353.3 0.8X
+basicTypes: Float with unsafe:false 325 327 1 3.1 325.1 0.8X
+basicTypes: Double with unsafe:false 335 336 1 3.0 335.0 0.8X
+Array: Int with unsafe:false 27 28 1 36.7 27.2 9.9X
+Array: Long with unsafe:false 40 41 1 25.0 40.0 6.7X
+Array: Float with unsafe:false 12 13 1 80.8 12.4 21.7X
+Array: Double with unsafe:false 21 21 1 48.6 20.6 13.1X
+Map of string->Double with unsafe:false 56 57 1 17.8 56.1 4.8X
diff --git a/core/benchmarks/KryoSerializerBenchmark-jdk11-results.txt b/core/benchmarks/KryoSerializerBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..6b148bde12d36
--- /dev/null
+++ b/core/benchmarks/KryoSerializerBenchmark-jdk11-results.txt
@@ -0,0 +1,12 @@
+================================================================================================
+Benchmark KryoPool vs old"pool of 1" implementation
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Benchmark KryoPool vs old"pool of 1" implementation: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+KryoPool:true 6208 8374 NaN 0.0 12416876.6 1.0X
+KryoPool:false 9084 11577 724 0.0 18168947.4 0.7X
+
+
diff --git a/core/benchmarks/KryoSerializerBenchmark-results.txt b/core/benchmarks/KryoSerializerBenchmark-results.txt
index c3ce336d93241..609f3298cbc00 100644
--- a/core/benchmarks/KryoSerializerBenchmark-results.txt
+++ b/core/benchmarks/KryoSerializerBenchmark-results.txt
@@ -1,12 +1,12 @@
================================================================================================
-Benchmark KryoPool vs "pool of 1"
+Benchmark KryoPool vs old"pool of 1" implementation
================================================================================================
-Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.14
-Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz
-Benchmark KryoPool vs "pool of 1": Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-KryoPool:true 2682 / 3425 0.0 5364627.9 1.0X
-KryoPool:false 8176 / 9292 0.0 16351252.2 0.3X
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Benchmark KryoPool vs old"pool of 1" implementation: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+KryoPool:true 6012 7586 NaN 0.0 12023020.2 1.0X
+KryoPool:false 9289 11566 909 0.0 18578683.1 0.6X
diff --git a/core/benchmarks/PropertiesCloneBenchmark-jdk11-results.txt b/core/benchmarks/PropertiesCloneBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..605b856d53382
--- /dev/null
+++ b/core/benchmarks/PropertiesCloneBenchmark-jdk11-results.txt
@@ -0,0 +1,40 @@
+================================================================================================
+Properties Cloning
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Empty Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 0 0 0 0.1 11539.0 1.0X
+Utils.cloneProperties 0 0 0 1.7 572.0 20.2X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+System Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 0 0 0 0.0 217514.0 1.0X
+Utils.cloneProperties 0 0 0 0.2 5387.0 40.4X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Small Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 1 1 0 0.0 634574.0 1.0X
+Utils.cloneProperties 0 0 0 0.3 3082.0 205.9X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Medium Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 3 3 0 0.0 2576565.0 1.0X
+Utils.cloneProperties 0 0 0 0.1 16071.0 160.3X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Large Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 5 5 0 0.0 5027248.0 1.0X
+Utils.cloneProperties 0 0 0 0.0 31842.0 157.9X
+
+
diff --git a/core/benchmarks/PropertiesCloneBenchmark-results.txt b/core/benchmarks/PropertiesCloneBenchmark-results.txt
new file mode 100644
index 0000000000000..5d332a147c698
--- /dev/null
+++ b/core/benchmarks/PropertiesCloneBenchmark-results.txt
@@ -0,0 +1,40 @@
+================================================================================================
+Properties Cloning
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Empty Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 0 0 0 0.1 13640.0 1.0X
+Utils.cloneProperties 0 0 0 1.6 608.0 22.4X
+
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+System Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 0 0 0 0.0 238968.0 1.0X
+Utils.cloneProperties 0 0 0 0.4 2318.0 103.1X
+
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Small Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 1 1 0 0.0 725849.0 1.0X
+Utils.cloneProperties 0 0 0 0.3 2900.0 250.3X
+
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Medium Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 3 3 0 0.0 2999676.0 1.0X
+Utils.cloneProperties 0 0 0 0.1 11734.0 255.6X
+
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Large Properties: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+SerializationUtils.clone 6 6 1 0.0 5846410.0 1.0X
+Utils.cloneProperties 0 0 0 0.0 22405.0 260.9X
+
+
diff --git a/core/benchmarks/XORShiftRandomBenchmark-jdk11-results.txt b/core/benchmarks/XORShiftRandomBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..9aa10e4835a2f
--- /dev/null
+++ b/core/benchmarks/XORShiftRandomBenchmark-jdk11-results.txt
@@ -0,0 +1,44 @@
+================================================================================================
+Pseudo random
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+nextInt: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+java.util.Random 1362 1362 0 73.4 13.6 1.0X
+XORShiftRandom 227 227 0 440.6 2.3 6.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+nextLong: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+java.util.Random 2725 2726 1 36.7 27.3 1.0X
+XORShiftRandom 694 694 1 144.1 6.9 3.9X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+nextDouble: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+java.util.Random 2727 2728 0 36.7 27.3 1.0X
+XORShiftRandom 693 694 0 144.2 6.9 3.9X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+nextGaussian: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+java.util.Random 7012 7016 4 14.3 70.1 1.0X
+XORShiftRandom 6065 6067 1 16.5 60.7 1.2X
+
+
+================================================================================================
+hash seed
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Hash seed: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+XORShiftRandom.hashSeed 36 37 1 276.5 3.6 1.0X
+
+
diff --git a/core/benchmarks/XORShiftRandomBenchmark-results.txt b/core/benchmarks/XORShiftRandomBenchmark-results.txt
index 1140489e4a7f3..4b069878b2e9b 100644
--- a/core/benchmarks/XORShiftRandomBenchmark-results.txt
+++ b/core/benchmarks/XORShiftRandomBenchmark-results.txt
@@ -2,43 +2,43 @@
Pseudo random
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-nextInt: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-java.util.Random 1362 / 1362 73.4 13.6 1.0X
-XORShiftRandom 227 / 227 440.6 2.3 6.0X
+nextInt: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+java.util.Random 1362 1396 59 73.4 13.6 1.0X
+XORShiftRandom 227 227 0 440.7 2.3 6.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-nextLong: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-java.util.Random 2732 / 2732 36.6 27.3 1.0X
-XORShiftRandom 629 / 629 159.0 6.3 4.3X
+nextLong: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+java.util.Random 2732 2732 1 36.6 27.3 1.0X
+XORShiftRandom 630 630 1 158.7 6.3 4.3X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-nextDouble: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-java.util.Random 2730 / 2730 36.6 27.3 1.0X
-XORShiftRandom 629 / 629 159.0 6.3 4.3X
+nextDouble: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+java.util.Random 2731 2732 1 36.6 27.3 1.0X
+XORShiftRandom 630 630 0 158.8 6.3 4.3X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-nextGaussian: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-java.util.Random 10288 / 10288 9.7 102.9 1.0X
-XORShiftRandom 6351 / 6351 15.7 63.5 1.6X
+nextGaussian: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+java.util.Random 8895 8899 4 11.2 88.9 1.0X
+XORShiftRandom 5049 5052 5 19.8 50.5 1.8X
================================================================================================
hash seed
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Hash seed: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-XORShiftRandom.hashSeed 1193 / 1195 8.4 119.3 1.0X
+Hash seed: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+XORShiftRandom.hashSeed 67 68 1 148.8 6.7 1.0X
diff --git a/core/pom.xml b/core/pom.xml
index 42fc2c4b3a287..38eb8adac500e 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -384,6 +384,11 @@
      <artifactId>curator-test</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-minikdc</artifactId>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>net.razorvine</groupId>
      <artifactId>pyrolite</artifactId>
@@ -551,6 +556,15 @@
+    <profile>
+      <id>scala-2.13</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.scala-lang.modules</groupId>
+          <artifactId>scala-parallel-collections_${scala.binary.version}</artifactId>
+        </dependency>
+      </dependencies>
+    </profile>
diff --git a/core/src/main/java/org/apache/spark/ExecutorPlugin.java b/core/src/main/java/org/apache/spark/ExecutorPlugin.java
index f86520c81df33..b25c46266247e 100644
--- a/core/src/main/java/org/apache/spark/ExecutorPlugin.java
+++ b/core/src/main/java/org/apache/spark/ExecutorPlugin.java
@@ -40,12 +40,15 @@ public interface ExecutorPlugin {
* Initialize the executor plugin.
*
* Each executor will, during its initialization, invoke this method on each
- * plugin provided in the spark.executor.plugins configuration.
+ * plugin provided in the spark.executor.plugins configuration. The Spark executor
+ * will wait for each plugin's init method to complete before continuing.
*
* Plugins should create threads in their implementation of this method for
* any polling, blocking, or intensive computation.
+ *
+ * @param pluginContext Context information for the executor where the plugin is running.
*/
- default void init() {}
+ default void init(ExecutorPluginContext pluginContext) {}
/**
* Clean up and terminate this plugin.
diff --git a/core/src/main/java/org/apache/spark/ExecutorPluginContext.java b/core/src/main/java/org/apache/spark/ExecutorPluginContext.java
new file mode 100644
index 0000000000000..8f018732b8217
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/ExecutorPluginContext.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import com.codahale.metrics.MetricRegistry;
+import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.annotation.Private;
+
+/**
+ * Encapsulates information about the executor when initializing {@link ExecutorPlugin} instances.
+ */
+@DeveloperApi
+public class ExecutorPluginContext {
+
+ public final MetricRegistry metricRegistry;
+ public final SparkConf sparkConf;
+ public final String executorId;
+ public final String executorHostName;
+ public final boolean isLocal;
+
+ @Private
+ public ExecutorPluginContext(
+ MetricRegistry registry,
+ SparkConf conf,
+ String id,
+ String hostName,
+ boolean local) {
+ metricRegistry = registry;
+ sparkConf = conf;
+ executorId = id;
+ executorHostName = hostName;
+ isLocal = local;
+ }
+
+}
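For reference, the following is a minimal sketch of an executor plugin written against the new init(ExecutorPluginContext) signature above; the class name, metric name, and polling interval are illustrative only and not part of this patch. It registers a counter through the supplied MetricRegistry and moves its periodic work onto a daemon thread, since the executor waits for init to return:

import com.codahale.metrics.Counter;

import org.apache.spark.ExecutorPlugin;
import org.apache.spark.ExecutorPluginContext;

public class HeartbeatPlugin implements ExecutorPlugin {
  private volatile boolean stopped = false;
  private Thread poller;

  @Override
  public void init(ExecutorPluginContext pluginContext) {
    // The executor waits for init() to return, so long-running work belongs on a thread.
    Counter beats = pluginContext.metricRegistry.counter("heartbeats");
    poller = new Thread(() -> {
      while (!stopped) {
        beats.inc();
        try {
          Thread.sleep(10000);
        } catch (InterruptedException ie) {
          Thread.currentThread().interrupt();
          return;
        }
      }
    }, "heartbeat-plugin-" + pluginContext.executorId);
    poller.setDaemon(true);
    poller.start();
  }

  @Override
  public void shutdown() {
    // Clean up the polling thread when the executor terminates the plugin.
    stopped = true;
    if (poller != null) {
      poller.interrupt();
    }
  }
}

Such a plugin would be enabled by adding its fully qualified class name to spark.executor.plugins; its shutdown hook mirrors the cleanup contract described in the interface's javadoc.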
diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
index 92bf0ecc1b5cb..a1e29a8c873da 100644
--- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
+++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
@@ -51,7 +51,6 @@ public NioBufferedFileInputStream(File file) throws IOException {
/**
* Checks whether data is left to be read from the input stream.
* @return true if data is left, false otherwise
- * @throws IOException
*/
private boolean refill() throws IOException {
if (!byteBuffer.hasRemaining()) {
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
index 4bfd2d358f36f..9a9d0c7946549 100644
--- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -54,7 +54,7 @@ public MemoryMode getMode() {
/**
* Returns the size of used memory in bytes.
*/
- protected long getUsed() {
+ public long getUsed() {
return used;
}
@@ -78,7 +78,6 @@ public void spill() throws IOException {
* @param size the amount of memory should be released
* @param trigger the MemoryConsumer that trigger this spilling
* @return the amount of released memory in bytes
- * @throws IOException
*/
public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
index 70c112b78911d..d30f3dad3c940 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle.api;
import java.io.IOException;
+import java.util.Optional;
import org.apache.spark.annotation.Private;
@@ -39,17 +40,31 @@ public interface ShuffleExecutorComponents {
/**
* Called once per map task to create a writer that will be responsible for persisting all the
* partitioned bytes written by that map task.
- * @param shuffleId Unique identifier for the shuffle the map task is a part of
- * @param mapId Within the shuffle, the identifier of the map task
- * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task
- * with the same (shuffleId, mapId) pair can be distinguished by the
- * different values of mapTaskAttemptId.
+ *
+ * @param shuffleId Unique identifier for the shuffle the map task is a part of
+ * @param mapId An ID of the map task. The ID is unique within this Spark application.
* @param numPartitions The number of partitions that will be written by the map task. Some of
-* these partitions may be empty.
+ * these partitions may be empty.
*/
ShuffleMapOutputWriter createMapOutputWriter(
int shuffleId,
- int mapId,
- long mapTaskAttemptId,
+ long mapId,
int numPartitions) throws IOException;
+
+ /**
+ * An optional extension for creating a map output writer that can optimize the transfer of a
+ * single partition file, as the entire result of a map task, to the backing store.
+ *
+ * Most implementations should return the default {@link Optional#empty()} to indicate that
+ * they do not support this optimization. This primarily is for backwards-compatibility in
+ * preserving an optimization in the local disk shuffle storage implementation.
+ *
+ * @param shuffleId Unique identifier for the shuffle the map task is a part of
+ * @param mapId An ID of the map task. The ID is unique within this Spark application.
+ */
+ default Optional<SingleSpillShuffleMapOutputWriter> createSingleFileMapOutputWriter(
+ int shuffleId,
+ long mapId) throws IOException {
+ return Optional.empty();
+ }
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
index 7fac00b7fbc3f..21abe9a57cd25 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -39,7 +39,7 @@ public interface ShuffleMapOutputWriter {
* for the same partition within any given map task. The partition identifier will be in the
* range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was
* provided upon the creation of this map output writer via
- * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}.
+ * {@link ShuffleExecutorComponents#createMapOutputWriter(int, long, int)}.
*
* Calls to this method will be invoked with monotonically increasing reducePartitionIds; each
* call to this method will be called with a reducePartitionId that is strictly greater than
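To make the getPartitionWriter contract above concrete, here is a hedged, self-contained sketch (not part of this patch) of a map task driving a ShuffleMapOutputWriter: partitions are requested in strictly increasing order, each partition's bytes go through its own partition writer, commitAllPartitions() returns the per-partition lengths, and abort() is invoked on failure. The class and method names are illustrative.

import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;

final class MapOutputWriteSketch {
  // Writes one byte[] per reduce partition and returns the committed partition lengths.
  static long[] writeAll(ShuffleMapOutputWriter mapWriter, byte[][] partitionedBytes)
      throws IOException {
    try {
      for (int p = 0; p < partitionedBytes.length; p++) {
        ShufflePartitionWriter partitionWriter = mapWriter.getPartitionWriter(p);
        try (OutputStream out = partitionWriter.openStream()) {
          out.write(partitionedBytes[p]);
        }
      }
      return mapWriter.commitAllPartitions();
    } catch (IOException | RuntimeException e) {
      // Give the plugin a chance to clean up partial output before propagating.
      mapWriter.abort(e);
      throw e;
    }
  }
}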
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
new file mode 100644
index 0000000000000..cad8dcfda52bc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * Optional extension for partition writing that is optimized for transferring a single
+ * file to the backing store.
+ */
+@Private
+public interface SingleSpillShuffleMapOutputWriter {
+
+ /**
+ * Transfer a file that contains the bytes of all the partitions written by this map task.
+ */
+ void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException;
+}
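As a rough illustration of this extension point, a hedged sketch of a local-disk-style implementation follows; the class and the way the final data file is chosen are hypothetical and not part of this patch. Because the single spill file already holds every partition's bytes back to back, the writer can simply move it into place and leave the partition lengths to be recorded elsewhere (for example in an index file):

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;

import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;

final class LocalSingleSpillMapOutputWriter implements SingleSpillShuffleMapOutputWriter {
  private final File finalDataFile;  // destination chosen by the surrounding components

  LocalSingleSpillMapOutputWriter(File finalDataFile) {
    this.finalDataFile = finalDataFile;
  }

  @Override
  public void transferMapSpillFile(File mapSpillFile, long[] partitionLengths)
      throws IOException {
    // The spill file is already the complete map output, so moving it is enough here;
    // partitionLengths would typically be persisted alongside it by the caller.
    Files.move(mapSpillFile.toPath(), finalDataFile.toPath(),
        StandardCopyOption.REPLACE_EXISTING);
  }
}

The spills.length == 1 branch added to UnsafeShuffleWriter below uses exactly this kind of writer whenever the shuffle plugin advertises one via createSingleFileMapOutputWriter.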
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index f75e932860f90..dc157eaa3b253 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -85,8 +85,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final Partitioner partitioner;
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
- private final int mapId;
- private final long mapTaskAttemptId;
+ private final long mapId;
private final Serializer serializer;
private final ShuffleExecutorComponents shuffleExecutorComponents;
@@ -106,8 +105,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
BypassMergeSortShuffleWriter(
BlockManager blockManager,
BypassMergeSortShuffleHandle<K, V> handle,
- int mapId,
- long mapTaskAttemptId,
+ long mapId,
SparkConf conf,
ShuffleWriteMetricsReporter writeMetrics,
ShuffleExecutorComponents shuffleExecutorComponents) {
@@ -117,7 +115,6 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.blockManager = blockManager;
final ShuffleDependency<K, V, V> dep = handle.dependency();
this.mapId = mapId;
- this.mapTaskAttemptId = mapTaskAttemptId;
this.shuffleId = dep.shuffleId();
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
@@ -130,11 +127,12 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents
- .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions);
+ .createMapOutputWriter(shuffleId, mapId, numPartitions);
try {
if (!records.hasNext()) {
partitionLengths = mapOutputWriter.commitAllPartitions();
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(), partitionLengths, mapId);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -167,7 +165,8 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
}
partitionLengths = writePartitionedData(mapOutputWriter);
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(), partitionLengths, mapId);
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 024756087bf7f..833744f4777ce 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -423,7 +423,6 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
*
* @return metadata for the spill files written by this sorter. If no records were ever inserted
* into this sorter, then this will return an empty array.
- * @throws IOException
*/
public SpillInfo[] closeAndGetSpills() throws IOException {
if (inMemSorter != null) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 9d05f03613ce9..d09282e61a9c7 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -17,9 +17,12 @@
package org.apache.spark.shuffle.sort;
+import java.nio.channels.Channels;
+import java.util.Optional;
import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
import java.util.Iterator;
import scala.Option;
@@ -31,7 +34,6 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
-import com.google.common.io.Files;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,8 +43,6 @@
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
-import org.apache.commons.io.output.CloseShieldOutputStream;
-import org.apache.commons.io.output.CountingOutputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
@@ -50,8 +50,12 @@
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.ShufflePartitionWriter;
+import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
@@ -65,23 +69,21 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
@VisibleForTesting
- static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
private final BlockManager blockManager;
- private final IndexShuffleBlockResolver shuffleBlockResolver;
private final TaskMemoryManager memoryManager;
private final SerializerInstance serializer;
private final Partitioner partitioner;
private final ShuffleWriteMetricsReporter writeMetrics;
+ private final ShuffleExecutorComponents shuffleExecutorComponents;
private final int shuffleId;
- private final int mapId;
+ private final long mapId;
private final TaskContext taskContext;
private final SparkConf sparkConf;
private final boolean transferToEnabled;
private final int initialSortBufferSize;
private final int inputBufferSizeInBytes;
- private final int outputBufferSizeInBytes;
@Nullable private MapStatus mapStatus;
@Nullable private ShuffleExternalSorter sorter;
@@ -103,27 +105,15 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream
*/
private boolean stopping = false;
- private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream {
-
- CloseAndFlushShieldOutputStream(OutputStream outputStream) {
- super(outputStream);
- }
-
- @Override
- public void flush() {
- // do nothing
- }
- }
-
public UnsafeShuffleWriter(
BlockManager blockManager,
- IndexShuffleBlockResolver shuffleBlockResolver,
TaskMemoryManager memoryManager,
SerializedShuffleHandle<K, V> handle,
- int mapId,
+ long mapId,
TaskContext taskContext,
SparkConf sparkConf,
- ShuffleWriteMetricsReporter writeMetrics) throws IOException {
+ ShuffleWriteMetricsReporter writeMetrics,
+ ShuffleExecutorComponents shuffleExecutorComponents) {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
@@ -132,7 +122,6 @@ public UnsafeShuffleWriter(
" reduce partitions");
}
this.blockManager = blockManager;
- this.shuffleBlockResolver = shuffleBlockResolver;
this.memoryManager = memoryManager;
this.mapId = mapId;
final ShuffleDependency<K, V, V> dep = handle.dependency();
@@ -140,6 +129,7 @@ public UnsafeShuffleWriter(
this.serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = writeMetrics;
+ this.shuffleExecutorComponents = shuffleExecutorComponents;
this.taskContext = taskContext;
this.sparkConf = sparkConf;
this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
@@ -147,8 +137,6 @@ public UnsafeShuffleWriter(
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
this.inputBufferSizeInBytes =
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
- this.outputBufferSizeInBytes =
- (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
open();
}
@@ -231,25 +219,17 @@ void closeAndWriteOutput() throws IOException {
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
final long[] partitionLengths;
- final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
- final File tmp = Utils.tempFileWith(output);
try {
- try {
- partitionLengths = mergeSpills(spills, tmp);
- } finally {
- for (SpillInfo spill : spills) {
- if (spill.file.exists() && ! spill.file.delete()) {
- logger.error("Error while deleting spill file {}", spill.file.getPath());
- }
- }
- }
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
+ partitionLengths = mergeSpills(spills);
} finally {
- if (tmp.exists() && !tmp.delete()) {
- logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && !spill.file.delete()) {
+ logger.error("Error while deleting spill file {}", spill.file.getPath());
+ }
}
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(), partitionLengths, mapId);
}
@VisibleForTesting
@@ -281,137 +261,153 @@ void forceSorterToSpill() throws IOException {
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
+ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+ long[] partitionLengths;
+ if (spills.length == 0) {
+ final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
+ .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
+ return mapWriter.commitAllPartitions();
+ } else if (spills.length == 1) {
+ Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
+ shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId);
+ if (maybeSingleFileWriter.isPresent()) {
+ // Here, we don't need to perform any metrics updates because the bytes written to this
+ // output file would have already been counted as shuffle bytes written.
+ partitionLengths = spills[0].partitionLengths;
+ maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
+ } else {
+ partitionLengths = mergeSpillsUsingStandardWriter(spills);
+ }
+ } else {
+ partitionLengths = mergeSpillsUsingStandardWriter(spills);
+ }
+ return partitionLengths;
+ }
+
+ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException {
+ long[] partitionLengths;
final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
- (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE());
+ (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE());
final boolean fastMergeIsSupported = !compressionEnabled ||
- CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
+ CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
+ final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
+ .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
try {
- if (spills.length == 0) {
- new FileOutputStream(outputFile).close(); // Create an empty file
- return new long[partitioner.numPartitions()];
- } else if (spills.length == 1) {
- // Here, we don't need to perform any metrics updates because the bytes written to this
- // output file would have already been counted as shuffle bytes written.
- Files.move(spills[0].file, outputFile);
- return spills[0].partitionLengths;
- } else {
- final long[] partitionLengths;
- // There are multiple spills to merge, so none of these spill files' lengths were counted
- // towards our shuffle write count or shuffle write time. If we use the slow merge path,
- // then the final output file's size won't necessarily be equal to the sum of the spill
- // files' sizes. To guard against this case, we look at the output file's actual size when
- // computing shuffle bytes written.
- //
- // We allow the individual merge methods to report their own IO times since different merge
- // strategies use different IO techniques. We count IO during merge towards the shuffle
- // shuffle write time, which appears to be consistent with the "not bypassing merge-sort"
- // branch in ExternalSorter.
- if (fastMergeEnabled && fastMergeIsSupported) {
- // Compression is disabled or we are using an IO compression codec that supports
- // decompression of concatenated compressed streams, so we can perform a fast spill merge
- // that doesn't need to interpret the spilled bytes.
- if (transferToEnabled && !encryptionEnabled) {
- logger.debug("Using transferTo-based fast merge");
- partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
- } else {
- logger.debug("Using fileStream-based fast merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
- }
+ // There are multiple spills to merge, so none of these spill files' lengths were counted
+ // towards our shuffle write count or shuffle write time. If we use the slow merge path,
+ // then the final output file's size won't necessarily be equal to the sum of the spill
+ // files' sizes. To guard against this case, we look at the output file's actual size when
+ // computing shuffle bytes written.
+ //
+ // We allow the individual merge methods to report their own IO times since different merge
+ // strategies use different IO techniques. We count IO during merge towards the shuffle
+ // write time, which appears to be consistent with the "not bypassing merge-sort" branch in
+ // ExternalSorter.
+ if (fastMergeEnabled && fastMergeIsSupported) {
+ // Compression is disabled or we are using an IO compression codec that supports
+ // decompression of concatenated compressed streams, so we can perform a fast spill merge
+ // that doesn't need to interpret the spilled bytes.
+ if (transferToEnabled && !encryptionEnabled) {
+ logger.debug("Using transferTo-based fast merge");
+ mergeSpillsWithTransferTo(spills, mapWriter);
} else {
- logger.debug("Using slow merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ logger.debug("Using fileStream-based fast merge");
+ mergeSpillsWithFileStream(spills, mapWriter, null);
}
- // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
- // in-memory records, we write out the in-memory records to a file but do not count that
- // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
- // to be counted as shuffle write, but this will lead to double-counting of the final
- // SpillInfo's bytes.
- writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
- writeMetrics.incBytesWritten(outputFile.length());
- return partitionLengths;
+ } else {
+ logger.debug("Using slow merge");
+ mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
}
- } catch (IOException e) {
- if (outputFile.exists() && !outputFile.delete()) {
- logger.error("Unable to delete output file {}", outputFile.getPath());
+ // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
+ // in-memory records, we write out the in-memory records to a file but do not count that
+ // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
+ // to be counted as shuffle write, but this will lead to double-counting of the final
+ // SpillInfo's bytes.
+ writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
+ partitionLengths = mapWriter.commitAllPartitions();
+ } catch (Exception e) {
+ try {
+ mapWriter.abort(e);
+ } catch (Exception e2) {
+ logger.warn("Failed to abort writing the map output.", e2);
+ e.addSuppressed(e2);
}
throw e;
}
+ return partitionLengths;
}
/**
* Merges spill files using Java FileStreams. This code path is typically slower than
* the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[],
- * File)}, and it's mostly used in cases where the IO compression codec does not support
- * concatenation of compressed data, when encryption is enabled, or when users have
- * explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
+ * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec
+ * does not support concatenation of compressed data, when encryption is enabled, or when
+ * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
* This code path might also be faster in cases where the individual partition size in a spill
* is small and the UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small
* disk IOs, which is inefficient. In those cases, using large buffers for the input and output
* files helps reduce the number of disk IOs, making the file merging faster.
*
* @param spills the spills to merge.
- * @param outputFile the file to write the merged data to.
+ * @param mapWriter the map output writer to use for output.
* @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpillsWithFileStream(
+ private void mergeSpillsWithFileStream(
SpillInfo[] spills,
- File outputFile,
+ ShuffleMapOutputWriter mapWriter,
@Nullable CompressionCodec compressionCodec) throws IOException {
- assert (spills.length >= 2);
final int numPartitions = partitioner.numPartitions();
- final long[] partitionLengths = new long[numPartitions];
final InputStream[] spillInputStreams = new InputStream[spills.length];
- final OutputStream bos = new BufferedOutputStream(
- new FileOutputStream(outputFile),
- outputBufferSizeInBytes);
- // Use a counting output stream to avoid having to close the underlying file and ask
- // the file system for its size after each partition is written.
- final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos);
-
boolean threwException = true;
try {
for (int i = 0; i < spills.length; i++) {
spillInputStreams[i] = new NioBufferedFileInputStream(
- spills[i].file,
- inputBufferSizeInBytes);
+ spills[i].file,
+ inputBufferSizeInBytes);
}
for (int partition = 0; partition < numPartitions; partition++) {
- final long initialFileLength = mergedFileOutputStream.getByteCount();
- // Shield the underlying output stream from close() and flush() calls, so that we can close
- // the higher level streams to make sure all data is really flushed and internal state is
- // cleaned.
- OutputStream partitionOutput = new CloseAndFlushShieldOutputStream(
- new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
- partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
- if (compressionCodec != null) {
- partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
- }
- for (int i = 0; i < spills.length; i++) {
- final long partitionLengthInSpill = spills[i].partitionLengths[partition];
- if (partitionLengthInSpill > 0) {
- InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
- partitionLengthInSpill, false);
- try {
- partitionInputStream = blockManager.serializerManager().wrapForEncryption(
- partitionInputStream);
- if (compressionCodec != null) {
- partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+ boolean copyThrewException = true;
+ ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
+ OutputStream partitionOutput = writer.openStream();
+ try {
+ partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput);
+ partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
+ if (compressionCodec != null) {
+ partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
+ }
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+
+ if (partitionLengthInSpill > 0) {
+ InputStream partitionInputStream = null;
+ boolean copySpillThrewException = true;
+ try {
+ partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+ partitionLengthInSpill, false);
+ partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+ partitionInputStream);
+ if (compressionCodec != null) {
+ partitionInputStream = compressionCodec.compressedInputStream(
+ partitionInputStream);
+ }
+ ByteStreams.copy(partitionInputStream, partitionOutput);
+ copySpillThrewException = false;
+ } finally {
+ Closeables.close(partitionInputStream, copySpillThrewException);
}
- ByteStreams.copy(partitionInputStream, partitionOutput);
- } finally {
- partitionInputStream.close();
}
}
+ copyThrewException = false;
+ } finally {
+ Closeables.close(partitionOutput, copyThrewException);
}
- partitionOutput.flush();
- partitionOutput.close();
- partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
+ long numBytesWritten = writer.getNumBytesWritten();
+ writeMetrics.incBytesWritten(numBytesWritten);
}
threwException = false;
} finally {
@@ -420,9 +416,7 @@ private long[] mergeSpillsWithFileStream(
for (InputStream stream : spillInputStreams) {
Closeables.close(stream, threwException);
}
- Closeables.close(mergedFileOutputStream, threwException);
}
- return partitionLengths;
}
/**
@@ -430,54 +424,46 @@ private long[] mergeSpillsWithFileStream(
* This is only safe when the IO compression codec and serializer support concatenation of
* serialized streams.
*
+ * @param spills the spills to merge.
+ * @param mapWriter the map output writer to use for output.
- * @return the partition lengths in the merged file.
*/
- private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
- assert (spills.length >= 2);
+ private void mergeSpillsWithTransferTo(
+ SpillInfo[] spills,
+ ShuffleMapOutputWriter mapWriter) throws IOException {
final int numPartitions = partitioner.numPartitions();
- final long[] partitionLengths = new long[numPartitions];
final FileChannel[] spillInputChannels = new FileChannel[spills.length];
final long[] spillInputChannelPositions = new long[spills.length];
- FileChannel mergedFileOutputChannel = null;
boolean threwException = true;
try {
for (int i = 0; i < spills.length; i++) {
spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
}
- // This file needs to opened in append mode in order to work around a Linux kernel bug that
- // affects transferTo; see SPARK-3948 for more details.
- mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
-
- long bytesWrittenToMergedFile = 0;
for (int partition = 0; partition < numPartitions; partition++) {
- for (int i = 0; i < spills.length; i++) {
- final long partitionLengthInSpill = spills[i].partitionLengths[partition];
- final FileChannel spillInputChannel = spillInputChannels[i];
- final long writeStartTime = System.nanoTime();
- Utils.copyFileStreamNIO(
- spillInputChannel,
- mergedFileOutputChannel,
- spillInputChannelPositions[i],
- partitionLengthInSpill);
- spillInputChannelPositions[i] += partitionLengthInSpill;
- writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
- bytesWrittenToMergedFile += partitionLengthInSpill;
- partitionLengths[partition] += partitionLengthInSpill;
+ boolean copyThrewException = true;
+ ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
+ WritableByteChannelWrapper resolvedChannel = writer.openChannelWrapper()
+ .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer)));
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ final FileChannel spillInputChannel = spillInputChannels[i];
+ final long writeStartTime = System.nanoTime();
+ Utils.copyFileStreamNIO(
+ spillInputChannel,
+ resolvedChannel.channel(),
+ spillInputChannelPositions[i],
+ partitionLengthInSpill);
+ copyThrewException = false;
+ spillInputChannelPositions[i] += partitionLengthInSpill;
+ writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
+ }
+ } finally {
+ Closeables.close(resolvedChannel, copyThrewException);
}
- }
- // Check the position after transferTo loop to see if it is in the right position and raise an
- // exception if it is incorrect. The position will not be increased to the expected length
- // after calling transferTo in kernel version 2.6.32. This issue is described at
- // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
- if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
- throw new IOException(
- "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
- "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
- " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
- "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
- "to disable this NIO feature."
- );
+ long numBytes = writer.getNumBytesWritten();
+ writeMetrics.incBytesWritten(numBytes);
}
threwException = false;
} finally {
@@ -487,9 +473,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
assert(spillInputChannelPositions[i] == spills[i].file.length());
Closeables.close(spillInputChannels[i], threwException);
}
- Closeables.close(mergedFileOutputChannel, threwException);
}
- return partitionLengths;
}
@Override
@@ -518,4 +502,30 @@ public Option<MapStatus> stop(boolean success) {
}
}
}
+
+ private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) {
+ try {
+ return writer.openStream();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper {
+ private final WritableByteChannel channel;
+
+ StreamFallbackChannelWrapper(OutputStream fallbackStream) {
+ this.channel = Channels.newChannel(fallbackStream);
+ }
+
+ @Override
+ public WritableByteChannel channel() {
+ return channel;
+ }
+
+ @Override
+ public void close() throws IOException {
+ channel.close();
+ }
+ }
}
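
Reviewer note, not part of the patch: the merge loops above now write through the pluggable writer API rather than a raw FileOutputStream. A minimal Scala sketch of that interaction follows, assuming only the method shapes visible in this diff (getPartitionWriter, openStream, getNumBytesWritten, commitAllPartitions); writeAllPartitions and partitionBytes are illustrative names, not Spark APIs.

import java.io.OutputStream

import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}

object MapOutputWriterSketch {
  // Drives one ShuffleMapOutputWriter for a map task; partitionBytes is a
  // hypothetical payload supplier standing in for the spill-merge loop above.
  def writeAllPartitions(
      mapWriter: ShuffleMapOutputWriter,
      numPartitions: Int,
      partitionBytes: Int => Array[Byte]): Array[Long] = {
    for (partition <- 0 until numPartitions) {
      val writer: ShufflePartitionWriter = mapWriter.getPartitionWriter(partition)
      val partitionOutput: OutputStream = writer.openStream()
      try {
        partitionOutput.write(partitionBytes(partition))
      } finally {
        // Closing the per-partition stream is what finalizes its recorded length.
        partitionOutput.close()
      }
      // writer.getNumBytesWritten is what the patch feeds into writeMetrics.incBytesWritten.
    }
    // For the local-disk implementation this also writes the index file.
    mapWriter.commitAllPartitions()
  }
}
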
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
index 02eb710737285..a0c7d3c248d48 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.sort.io;
+import java.util.Optional;
+
import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.SparkConf;
@@ -24,6 +26,7 @@
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.storage.BlockManager;
public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents {
@@ -58,8 +61,7 @@ public void initializeExecutor(String appId, String execId) {
@Override
public ShuffleMapOutputWriter createMapOutputWriter(
int shuffleId,
- int mapId,
- long mapTaskAttemptId,
+ long mapId,
int numPartitions) {
if (blockResolver == null) {
throw new IllegalStateException(
@@ -68,4 +70,15 @@ public ShuffleMapOutputWriter createMapOutputWriter(
return new LocalDiskShuffleMapOutputWriter(
shuffleId, mapId, numPartitions, blockResolver, sparkConf);
}
+
+ @Override
+ public Optional<SingleSpillShuffleMapOutputWriter> createSingleFileMapOutputWriter(
+ int shuffleId,
+ long mapId) {
+ if (blockResolver == null) {
+ throw new IllegalStateException(
+ "Executor components must be initialized before getting writers.");
+ }
+ return Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver));
+ }
}
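
As a reading aid rather than part of the change: the Optional returned here lets a map task with a single spill skip the partition-by-partition copy entirely. A hedged Scala sketch of how a caller might branch, assuming only the signatures visible in this file and in the single-spill writer added later in this patch; SingleSpillFastPathSketch and writeMapOutput are illustrative names.

import java.io.File

import org.apache.spark.shuffle.api.ShuffleExecutorComponents

object SingleSpillFastPathSketch {
  def writeMapOutput(
      components: ShuffleExecutorComponents,
      shuffleId: Int,
      mapId: Long,
      numPartitions: Int,
      spillFile: File,
      partitionLengths: Array[Long]): Unit = {
    val singleFileWriter = components.createSingleFileMapOutputWriter(shuffleId, mapId)
    if (singleFileWriter.isPresent) {
      // The lone spill already has the final on-disk layout: hand it over as-is.
      singleFileWriter.get().transferMapSpillFile(spillFile, partitionLengths)
    } else {
      // Otherwise fall back to the general writer and copy partition by partition,
      // as in the UnsafeShuffleWriter merge loops earlier in this patch.
      val mapWriter = components.createMapOutputWriter(shuffleId, mapId, numPartitions)
      // ... per-partition copy elided ...
      mapWriter.commitAllPartitions()
    }
  }
}
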
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
index 7fc19b1270a46..a6529fd76188a 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
@@ -24,8 +24,8 @@
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;
-
import java.util.Optional;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -48,12 +48,13 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class);
private final int shuffleId;
- private final int mapId;
+ private final long mapId;
private final IndexShuffleBlockResolver blockResolver;
private final long[] partitionLengths;
private final int bufferSize;
private int lastPartitionId = -1;
private long currChannelPosition;
+ private long bytesWrittenToMergedFile = 0L;
private final File outputFile;
private File outputTempFile;
@@ -63,7 +64,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
public LocalDiskShuffleMapOutputWriter(
int shuffleId,
- int mapId,
+ long mapId,
int numPartitions,
IndexShuffleBlockResolver blockResolver,
SparkConf sparkConf) {
@@ -97,6 +98,18 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I
@Override
public long[] commitAllPartitions() throws IOException {
+ // Check the position after the transferTo loop to see if it is in the right position and raise an
+ // exception if it is incorrect. The position will not be increased to the expected length
+ // after calling transferTo in kernel version 2.6.32. This issue is described at
+ // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+ if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) {
+ throw new IOException(
+ "Current position " + outputFileChannel.position() + " does not equal expected " +
+ "position " + bytesWrittenToMergedFile + " after transferTo. Please check your " +
+ "kernel version to see if it is 2.6.32, as there is a kernel bug which will lead " +
+ "to unexpected behavior when using transferTo. You can set " +
+ "spark.file.transferTo=false to disable this NIO feature.");
+ }
cleanUp();
File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
@@ -133,11 +146,10 @@ private void initStream() throws IOException {
}
private void initChannel() throws IOException {
- if (outputFileStream == null) {
- outputFileStream = new FileOutputStream(outputTempFile, true);
- }
+ // This file needs to be opened in append mode in order to work around a Linux kernel bug that
+ // affects transferTo; see SPARK-3948 for more details.
if (outputFileChannel == null) {
- outputFileChannel = outputFileStream.getChannel();
+ outputFileChannel = new FileOutputStream(outputTempFile, true).getChannel();
}
}
@@ -227,6 +239,7 @@ public void write(byte[] buf, int pos, int length) throws IOException {
public void close() {
isClosed = true;
partitionLengths[partitionId] = count;
+ bytesWrittenToMergedFile += count;
}
private void verifyNotClosed() {
@@ -257,6 +270,7 @@ public WritableByteChannel channel() {
@Override
public void close() throws IOException {
partitionLengths[partitionId] = getCount();
+ bytesWrittenToMergedFile += partitionLengths[partitionId];
}
}
}
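
For context: commitAllPartitions now owns the transferTo sanity check that previously lived in UnsafeShuffleWriter. Stripped down to a standalone Scala sketch (names are illustrative, not Spark APIs), the invariant is simply that the channel position advanced by exactly the number of bytes the partition writers recorded.

object TransferToCheckSketch {
  // channelPosition stands in for outputFileChannel.position(),
  // bytesTracked for bytesWrittenToMergedFile.
  def verifyAllBytesTransferred(channelPosition: Long, bytesTracked: Long): Unit = {
    if (channelPosition != bytesTracked) {
      throw new java.io.IOException(
        s"Current position $channelPosition does not equal expected position " +
        s"$bytesTracked after transferTo; see SPARK-3948 and JDK-7052359, or set " +
        "spark.file.transferTo=false to avoid the NIO copy path.")
    }
  }
}
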
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
new file mode 100644
index 0000000000000..c8b41992a8919
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
+import org.apache.spark.util.Utils;
+
+public class LocalDiskSingleSpillMapOutputWriter
+ implements SingleSpillShuffleMapOutputWriter {
+
+ private final int shuffleId;
+ private final long mapId;
+ private final IndexShuffleBlockResolver blockResolver;
+
+ public LocalDiskSingleSpillMapOutputWriter(
+ int shuffleId,
+ long mapId,
+ IndexShuffleBlockResolver blockResolver) {
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.blockResolver = blockResolver;
+ }
+
+ @Override
+ public void transferMapSpillFile(
+ File mapSpillFile,
+ long[] partitionLengths) throws IOException {
+ // The map spill file already has the proper format, and it contains all of the partition data.
+ // So just transfer it directly to the destination without any merging.
+ File outputFile = blockResolver.getDataFile(shuffleId, mapId);
+ File tempFile = Utils.tempFileWith(outputFile);
+ Files.move(mapSpillFile.toPath(), tempFile.toPath());
+ blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index d320ba3139541..b15365fe54ad6 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -886,6 +886,7 @@ public void reset() {
numKeys = 0;
numValues = 0;
freeArray(longArray);
+ longArray = null;
while (dataPages.size() > 0) {
MemoryBlock dataPage = dataPages.removeLast();
freePage(dataPage);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 1b206c11d9a8e..55e4e609c3c7b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -447,8 +447,6 @@ public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
/**
* Merges another UnsafeExternalSorter into this one; the other one will be emptied.
- *
- * @throws IOException
*/
public void merge(UnsafeExternalSorter other) throws IOException {
other.spill();
diff --git a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js
index 3ef1a76fd7202..b28c981da20a5 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js
@@ -286,7 +286,7 @@ $(document).ready(function () {
" Show Additional Metrics" +
"" +
"" +
- "
Select All
" +
+ "
Select All
" +
"
Scheduler Delay
" +
"
Task Deserialization Time
" +
"
Shuffle Read Blocked Time
" +
diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css
index 10bceae2fbdda..3f31403eaeef3 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css
@@ -207,6 +207,12 @@ rect.getting-result-time-proportion {
border-color: #3EC0FF;
}
+.vis-timeline .vis-item.executor.added.vis-selected {
+ background-color: #00AAFF;
+ border-color: #184C66;
+ z-index: 2;
+}
+
.legend-area rect.executor-added-legend {
fill: #A0DFFF;
stroke: #3EC0FF;
@@ -217,17 +223,17 @@ rect.getting-result-time-proportion {
border-color: #FF4D6D;
}
+.vis-timeline .vis-item.executor.removed.vis-selected {
+ background-color: #FF6680;
+ border-color: #661F2C;
+ z-index: 2;
+}
+
.legend-area rect.executor-removed-legend {
fill: #FFA1B0;
stroke: #FF4D6D;
}
-.vis-timeline .vis-item.executor.vis-selected {
- background-color: #A2FCC0;
- border-color: #36F572;
- z-index: 2;
-}
-
tr.corresponding-item-hover > td, tr.corresponding-item-hover > th {
background-color: #D6FFE4 !important;
}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
old mode 100644
new mode 100755
index 3e28816ba61b6..801c449fd626f
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -245,9 +245,9 @@ a.expandbutton {
max-width: 600px;
}
-.paginate_button.active > a {
- color: #999999;
- text-decoration: underline;
+.paginate_button.active {
+ border: 1px solid #979797 !important;
+ background: white linear-gradient(to bottom, #fff 0%, #dcdcdc 100%);
}
.title-table {
@@ -263,32 +263,36 @@ a.expandbutton {
width: 200px;
}
+.select-all-div-checkbox-div {
+ width: 90px;
+}
+
.scheduler-delay-checkbox-div {
- width: 120px;
+ width: 130px;
}
.task-deserialization-time-checkbox-div {
- width: 175px;
+ width: 190px;
}
.shuffle-read-blocked-time-checkbox-div {
- width: 187px;
+ width: 200px;
}
.shuffle-remote-reads-checkbox-div {
- width: 157px;
+ width: 170px;
}
.result-serialization-time-checkbox-div {
- width: 171px;
+ width: 185px;
}
.getting-result-time-checkbox-div {
- width: 141px;
+ width: 155px;
}
.peak-execution-memory-checkbox-div {
- width: 170px;
+ width: 180px;
}
#active-tasks-table th {
diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
index 9f59295059d30..4e417679ca663 100644
--- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -107,9 +107,9 @@ private[spark] class BarrierCoordinator(
private var timerTask: TimerTask = null
// Init a TimerTask for a barrier() call.
- private def initTimerTask(): Unit = {
+ private def initTimerTask(state: ContextBarrierState): Unit = {
timerTask = new TimerTask {
- override def run(): Unit = synchronized {
+ override def run(): Unit = state.synchronized {
// Timeout current barrier() call, fail all the sync requests.
requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " +
s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " +
@@ -148,7 +148,7 @@ private[spark] class BarrierCoordinator(
// If this is the first sync message received for a barrier() call, start timer to ensure
// we may timeout for the sync.
if (requesters.isEmpty) {
- initTimerTask()
+ initTimerTask(this)
timer.schedule(timerTask, timeoutInSecs * 1000)
}
// Add the requester to array of RPCCallContexts pending for reply.
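
The point of the BarrierCoordinator change above is which monitor the timeout task takes: the anonymous TimerTask used to synchronize on itself, so the timeout path and the barrier-sync path held different locks. A hedged Scala sketch of the corrected shape; BarrierStateSketch and its members are stand-ins, not the real ContextBarrierState.

import java.util.{Timer, TimerTask}

class BarrierStateSketch(timeoutMs: Long) {
  private val timer = new Timer("barrier-timeout-sketch", true)
  var pendingRequesters: List[String] = Nil

  // The task locks on the state object it is handed, not on itself.
  private def initTimerTask(state: BarrierStateSketch): TimerTask = new TimerTask {
    override def run(): Unit = state.synchronized {
      // Fail/clear pending requesters under the same lock that guards onRequest.
      state.pendingRequesters = Nil
    }
  }

  def onRequest(requester: String): Unit = this.synchronized {
    if (pendingRequesters.isEmpty) {
      timer.schedule(initTimerTask(this), timeoutMs)
    }
    pendingRequesters = requester :: pendingRequesters
  }
}
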
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 24c83993b1b60..dfbd7d1c6f058 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -71,7 +71,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
private val listeners = new ConcurrentLinkedQueue[CleanerListener]()
- private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
+ private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() }
private val periodicGCService: ScheduledExecutorService =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("context-cleaner-periodic-gc")
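
A large share of the Scala hunks in this patch (ContextCleaner above, FutureAction, SecurityManager, SparkContext and others below) are the same mechanical change: dropping Scala's deprecated procedure syntax in favour of an explicit result type. A minimal illustration, assuming Scala 2.12, where the old form still compiles but warns under -deprecation:

object ProcedureSyntaxSketch {
  // Old procedure syntax: no '=', the result type is implicitly Unit.
  def oldStyle() { println("cleaning") }

  // The form this patch converts to: an explicit ': Unit ='.
  def newStyle(): Unit = println("cleaning")
}
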
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index fb051a8c0db8e..f0ac9acd90156 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -93,7 +93,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val shuffleId: Int = _rdd.context.newShuffleId()
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
- shuffleId, _rdd.partitions.length, this)
+ shuffleId, this)
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 8230533f9d245..4bdcafce0d75a 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -115,7 +115,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
@volatile private var _cancelled: Boolean = false
- override def cancel() {
+ override def cancel(): Unit = {
_cancelled = true
jobWaiter.cancel()
}
@@ -132,7 +132,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
value.get.get
}
- override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
jobWaiter.completionFuture onComplete {_ => func(value.get)}
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index d878fc527791a..b610e5d4d9304 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -19,10 +19,11 @@ package org.apache.spark
import java.io._
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
+import java.util.concurrent.locks.ReentrantReadWriteLock
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.JavaConverters._
-import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map}
+import scala.collection.mutable.{HashMap, ListBuffer, Map}
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
@@ -41,14 +42,36 @@ import org.apache.spark.util._
* Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single
* ShuffleMapStage.
*
- * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of
+ * This class maintains a mapping from map index to `MapStatus`. It also maintains a cache of
* serialized map statuses in order to speed up tasks' requests for map output statuses.
*
* All public methods of this class are thread-safe.
*/
private class ShuffleStatus(numPartitions: Int) {
- // All accesses to the following state must be guarded with `this.synchronized`.
+ private val (readLock, writeLock) = {
+ val lock = new ReentrantReadWriteLock()
+ (lock.readLock(), lock.writeLock())
+ }
+
+ // All accesses to the following state must be guarded with `withReadLock` or `withWriteLock`.
+ private def withReadLock[B](fn: => B): B = {
+ readLock.lock()
+ try {
+ fn
+ } finally {
+ readLock.unlock()
+ }
+ }
+
+ private def withWriteLock[B](fn: => B): B = {
+ writeLock.lock()
+ try {
+ fn
+ } finally {
+ writeLock.unlock()
+ }
+ }
/**
* MapStatus for each partition. The index of the array is the map partition id.
@@ -88,12 +111,12 @@ private class ShuffleStatus(numPartitions: Int) {
* Register a map output. If there is already a registered location for the map output then it
* will be replaced by the new location.
*/
- def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
- if (mapStatuses(mapId) == null) {
+ def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
+ if (mapStatuses(mapIndex) == null) {
_numAvailableOutputs += 1
invalidateSerializedMapOutputStatusCache()
}
- mapStatuses(mapId) = status
+ mapStatuses(mapIndex) = status
}
/**
@@ -101,10 +124,10 @@ private class ShuffleStatus(numPartitions: Int) {
* This is a no-op if there is no registered map output or if the registered output is from a
* different block manager.
*/
- def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
- if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
+ def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
+ if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
_numAvailableOutputs -= 1
- mapStatuses(mapId) = null
+ mapStatuses(mapIndex) = null
invalidateSerializedMapOutputStatusCache()
}
}
@@ -113,7 +136,7 @@ private class ShuffleStatus(numPartitions: Int) {
* Removes all shuffle outputs associated with this host. Note that this will also remove
* outputs which are served by an external shuffle server (if one exists).
*/
- def removeOutputsOnHost(host: String): Unit = {
+ def removeOutputsOnHost(host: String): Unit = withWriteLock {
removeOutputsByFilter(x => x.host == host)
}
@@ -122,7 +145,7 @@ private class ShuffleStatus(numPartitions: Int) {
* remove outputs which are served by an external shuffle server (if one exists), as they are
* still registered with that execId.
*/
- def removeOutputsOnExecutor(execId: String): Unit = synchronized {
+ def removeOutputsOnExecutor(execId: String): Unit = withWriteLock {
removeOutputsByFilter(x => x.executorId == execId)
}
@@ -130,11 +153,11 @@ private class ShuffleStatus(numPartitions: Int) {
* Removes all shuffle outputs which satisfy the filter. Note that this will also
* remove outputs which are served by an external shuffle server (if one exists).
*/
- def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
- for (mapId <- 0 until mapStatuses.length) {
- if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
+ def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
+ for (mapIndex <- mapStatuses.indices) {
+ if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
_numAvailableOutputs -= 1
- mapStatuses(mapId) = null
+ mapStatuses(mapIndex) = null
invalidateSerializedMapOutputStatusCache()
}
}
@@ -143,14 +166,14 @@ private class ShuffleStatus(numPartitions: Int) {
/**
* Number of partitions that have shuffle outputs.
*/
- def numAvailableOutputs: Int = synchronized {
+ def numAvailableOutputs: Int = withReadLock {
_numAvailableOutputs
}
/**
* Returns the sequence of partition ids that are missing (i.e. needs to be computed).
*/
- def findMissingPartitions(): Seq[Int] = synchronized {
+ def findMissingPartitions(): Seq[Int] = withReadLock {
val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
assert(missing.size == numPartitions - _numAvailableOutputs,
s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
@@ -169,18 +192,31 @@ private class ShuffleStatus(numPartitions: Int) {
def serializedMapStatus(
broadcastManager: BroadcastManager,
isLocal: Boolean,
- minBroadcastSize: Int): Array[Byte] = synchronized {
- if (cachedSerializedMapStatus eq null) {
- val serResult = MapOutputTracker.serializeMapStatuses(
+ minBroadcastSize: Int): Array[Byte] = {
+ var result: Array[Byte] = null
+
+ withReadLock {
+ if (cachedSerializedMapStatus != null) {
+ result = cachedSerializedMapStatus
+ }
+ }
+
+ if (result == null) withWriteLock {
+ if (cachedSerializedMapStatus == null) {
+ val serResult = MapOutputTracker.serializeMapStatuses(
mapStatuses, broadcastManager, isLocal, minBroadcastSize)
- cachedSerializedMapStatus = serResult._1
- cachedSerializedBroadcast = serResult._2
+ cachedSerializedMapStatus = serResult._1
+ cachedSerializedBroadcast = serResult._2
+ }
+ // The following line has to be outside the if statement since it's possible that another thread
+ // initializes cachedSerializedMapStatus in-between `withReadLock` and `withWriteLock`.
+ result = cachedSerializedMapStatus
}
- cachedSerializedMapStatus
+ result
}
// Used in testing.
- def hasCachedSerializedBroadcast: Boolean = synchronized {
+ def hasCachedSerializedBroadcast: Boolean = withReadLock {
cachedSerializedBroadcast != null
}
@@ -188,14 +224,14 @@ private class ShuffleStatus(numPartitions: Int) {
* Helper function which provides thread-safe access to the mapStatuses array.
* The function should NOT mutate the array.
*/
- def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized {
+ def withMapStatuses[T](f: Array[MapStatus] => T): T = withReadLock {
f(mapStatuses)
}
/**
* Clears the cached serialized map output statuses.
*/
- def invalidateSerializedMapOutputStatusCache(): Unit = synchronized {
+ def invalidateSerializedMapOutputStatusCache(): Unit = withWriteLock {
if (cachedSerializedBroadcast != null) {
// Prevent errors during broadcast cleanup from crashing the DAGScheduler (see SPARK-21444)
Utils.tryLogNonFatalError {
@@ -272,7 +308,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
/** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */
- protected def sendTracker(message: Any) {
+ protected def sendTracker(message: Any): Unit = {
val response = askTracker[Boolean](message)
if (response != true) {
throw new SparkException(
@@ -282,8 +318,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
// For testing
def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
- getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
+ : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+ getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1, useOldFetchProtocol = false)
}
/**
@@ -292,18 +328,22 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
* endPartition is excluded from the range).
*
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
- * and the second item is a sequence of (shuffle block id, shuffle block size) tuples
- * describing the shuffle blocks that are stored at that block manager.
+ * and the second item is a sequence of (shuffle block id, shuffle block size, map index)
+ * tuples describing the shuffle blocks that are stored at that block manager.
*/
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
+ def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ useOldFetchProtocol: Boolean)
+ : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
/**
* Deletes map output status information for the specified shuffle stage.
*/
def unregisterShuffle(shuffleId: Int): Unit
- def stop() {}
+ def stop(): Unit = {}
}
/**
@@ -412,21 +452,21 @@ private[spark] class MapOutputTrackerMaster(
shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
}
- def registerShuffle(shuffleId: Int, numMaps: Int) {
+ def registerShuffle(shuffleId: Int, numMaps: Int): Unit = {
if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
- def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- shuffleStatuses(shuffleId).addMapOutput(mapId, status)
+ def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
+ shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
}
/** Unregister map output information of the given shuffle, mapper and block manager */
- def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ def unregisterMapOutput(shuffleId: Int, mapIndex: Int, bmAddress: BlockManagerId): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
- shuffleStatus.removeMapOutput(mapId, bmAddress)
+ shuffleStatus.removeMapOutput(mapIndex, bmAddress)
incrementEpoch()
case None =>
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
@@ -434,7 +474,7 @@ private[spark] class MapOutputTrackerMaster(
}
/** Unregister all map output information of the given shuffle. */
- def unregisterAllMapOutput(shuffleId: Int) {
+ def unregisterAllMapOutput(shuffleId: Int): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(x => true)
@@ -446,7 +486,7 @@ private[spark] class MapOutputTrackerMaster(
}
/** Unregister shuffle data */
- def unregisterShuffle(shuffleId: Int) {
+ def unregisterShuffle(shuffleId: Int): Unit = {
shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
shuffleStatus.invalidateSerializedMapOutputStatusCache()
}
@@ -629,7 +669,7 @@ private[spark] class MapOutputTrackerMaster(
None
}
- def incrementEpoch() {
+ def incrementEpoch(): Unit = {
epochLock.synchronized {
epoch += 1
logDebug("Increasing epoch to " + epoch)
@@ -645,20 +685,25 @@ private[spark] class MapOutputTrackerMaster(
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
// This method is only called in local-mode.
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ useOldFetchProtocol: Boolean)
+ : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
shuffleStatus.withMapStatuses { statuses =>
- MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+ MapOutputTracker.convertMapStatuses(
+ shuffleId, startPartition, endPartition, statuses, useOldFetchProtocol)
}
case None =>
Iterator.empty
}
}
- override def stop() {
+ override def stop(): Unit = {
mapOutputRequests.offer(PoisonPill)
threadpool.shutdown()
sendTracker(StopMapOutputTracker)
@@ -685,12 +730,17 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
private val fetchingLock = new KeyLock[Int]
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
- override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ override def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ useOldFetchProtocol: Boolean)
+ : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
- MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+ MapOutputTracker.convertMapStatuses(
+ shuffleId, startPartition, endPartition, statuses, useOldFetchProtocol)
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
@@ -832,19 +882,21 @@ private[spark] object MapOutputTracker extends Logging {
* @param shuffleId Identifier for the shuffle
* @param startPartition Start of map output partition ID range (included in range)
* @param endPartition End of map output partition ID range (excluded from range)
- * @param statuses List of map statuses, indexed by map ID.
+ * @param statuses List of map statuses, indexed by map partition index.
+ * @param useOldFetchProtocol Whether to use the old shuffle fetch protocol.
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
- * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
- * describing the shuffle blocks that are stored at that block manager.
+ * and the second item is a sequence of (shuffle block id, shuffle block size, map index)
+ * tuples describing the shuffle blocks that are stored at that block manager.
*/
def convertMapStatuses(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
- statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ statuses: Array[MapStatus],
+ useOldFetchProtocol: Boolean): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
assert (statuses != null)
- val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]]
- for ((status, mapId) <- statuses.iterator.zipWithIndex) {
+ val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
+ for ((status, mapIndex) <- statuses.iterator.zipWithIndex) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
logError(errorMessage)
@@ -853,8 +905,15 @@ private[spark] object MapOutputTracker extends Logging {
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
if (size != 0) {
- splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
- ((ShuffleBlockId(shuffleId, mapId, part), size))
+ if (useOldFetchProtocol) {
+ // While we use the old shuffle fetch protocol, we use mapIndex as mapId in the
+ // ShuffleBlockId.
+ splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
+ ((ShuffleBlockId(shuffleId, mapIndex, part), size, mapIndex))
+ } else {
+ splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
+ ((ShuffleBlockId(shuffleId, status.mapTaskId, part), size, mapIndex))
+ }
}
}
}
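
The ShuffleStatus rework above swaps coarse `this.synchronized` blocks for a ReentrantReadWriteLock plus a double-checked path in serializedMapStatus: read under the read lock, and only if the cache is empty re-check and fill it under the write lock. The same pattern reduced to a self-contained Scala sketch; CachedValue and compute are illustrative names, not Spark APIs.

import java.util.concurrent.locks.ReentrantReadWriteLock

class CachedValue[T >: Null <: AnyRef](compute: () => T) {
  private val lock = new ReentrantReadWriteLock()
  private var cached: T = null

  def get(): T = {
    var result: T = null
    lock.readLock().lock()
    try {
      if (cached != null) result = cached
    } finally {
      lock.readLock().unlock()
    }
    if (result == null) {
      lock.writeLock().lock()
      try {
        if (cached == null) cached = compute()
        // Re-read under the write lock: another thread may have filled the cache
        // between releasing the read lock and acquiring the write lock.
        result = cached
      } finally {
        lock.writeLock().unlock()
      }
    }
    result
  }
}
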
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 77db0f5d0eaa7..d061627bea69c 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -108,12 +108,12 @@ private[spark] class SecurityManager(
* Admin acls should be set before the view or modify acls. If you modify the admin
* acls you should also set the view and modify acls again to pick up the changes.
*/
- def setViewAcls(defaultUsers: Set[String], allowedUsers: Seq[String]) {
+ def setViewAcls(defaultUsers: Set[String], allowedUsers: Seq[String]): Unit = {
viewAcls = adminAcls ++ defaultUsers ++ allowedUsers
logInfo("Changing view acls to: " + viewAcls.mkString(","))
}
- def setViewAcls(defaultUser: String, allowedUsers: Seq[String]) {
+ def setViewAcls(defaultUser: String, allowedUsers: Seq[String]): Unit = {
setViewAcls(Set[String](defaultUser), allowedUsers)
}
@@ -121,7 +121,7 @@ private[spark] class SecurityManager(
* Admin acls groups should be set before the view or modify acls groups. If you modify the admin
* acls groups you should also set the view and modify acls groups again to pick up the changes.
*/
- def setViewAclsGroups(allowedUserGroups: Seq[String]) {
+ def setViewAclsGroups(allowedUserGroups: Seq[String]): Unit = {
viewAclsGroups = adminAclsGroups ++ allowedUserGroups
logInfo("Changing view acls groups to: " + viewAclsGroups.mkString(","))
}
@@ -149,7 +149,7 @@ private[spark] class SecurityManager(
* Admin acls should be set before the view or modify acls. If you modify the admin
* acls you should also set the view and modify acls again to pick up the changes.
*/
- def setModifyAcls(defaultUsers: Set[String], allowedUsers: Seq[String]) {
+ def setModifyAcls(defaultUsers: Set[String], allowedUsers: Seq[String]): Unit = {
modifyAcls = adminAcls ++ defaultUsers ++ allowedUsers
logInfo("Changing modify acls to: " + modifyAcls.mkString(","))
}
@@ -158,7 +158,7 @@ private[spark] class SecurityManager(
* Admin acls groups should be set before the view or modify acls groups. If you modify the admin
* acls groups you should also set the view and modify acls groups again to pick up the changes.
*/
- def setModifyAclsGroups(allowedUserGroups: Seq[String]) {
+ def setModifyAclsGroups(allowedUserGroups: Seq[String]): Unit = {
modifyAclsGroups = adminAclsGroups ++ allowedUserGroups
logInfo("Changing modify acls groups to: " + modifyAclsGroups.mkString(","))
}
@@ -186,7 +186,7 @@ private[spark] class SecurityManager(
* Admin acls should be set before the view or modify acls. If you modify the admin
* acls you should also set the view and modify acls again to pick up the changes.
*/
- def setAdminAcls(adminUsers: Seq[String]) {
+ def setAdminAcls(adminUsers: Seq[String]): Unit = {
adminAcls = adminUsers.toSet
logInfo("Changing admin acls to: " + adminAcls.mkString(","))
}
@@ -195,12 +195,12 @@ private[spark] class SecurityManager(
* Admin acls groups should be set before the view or modify acls groups. If you modify the admin
* acls groups you should also set the view and modify acls groups again to pick up the changes.
*/
- def setAdminAclsGroups(adminUserGroups: Seq[String]) {
+ def setAdminAclsGroups(adminUserGroups: Seq[String]): Unit = {
adminAclsGroups = adminUserGroups.toSet
logInfo("Changing admin acls groups to: " + adminAclsGroups.mkString(","))
}
- def setAcls(aclSetting: Boolean) {
+ def setAcls(aclSetting: Boolean): Unit = {
aclsOn = aclSetting
logInfo("Changing acls enabled to: " + aclsOn)
}
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 24be54ec91828..3a2eaae092e8d 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -504,7 +504,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
* Checks for illegal or deprecated config settings. Throws an exception for the former. Not
* idempotent - may mutate this conf object to convert deprecated settings to supported ones.
*/
- private[spark] def validateSettings() {
+ private[spark] def validateSettings(): Unit = {
if (contains("spark.local.dir")) {
val msg = "Note that spark.local.dir will be overridden by the value set by " +
"the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS" +
@@ -548,23 +548,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
}
}
- if (contains("spark.master") && get("spark.master").startsWith("yarn-")) {
- val warning = s"spark.master ${get("spark.master")} is deprecated in Spark 2.0+, please " +
- "instead use \"yarn\" with specified deploy mode."
-
- get("spark.master") match {
- case "yarn-cluster" =>
- logWarning(warning)
- set("spark.master", "yarn")
- set(SUBMIT_DEPLOY_MODE, "cluster")
- case "yarn-client" =>
- logWarning(warning)
- set("spark.master", "yarn")
- set(SUBMIT_DEPLOY_MODE, "client")
- case _ => // Any other unexpected master will be checked when creating scheduler backend.
- }
- }
-
if (contains(SUBMIT_DEPLOY_MODE)) {
get(SUBMIT_DEPLOY_MODE) match {
case "cluster" | "client" =>
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 396d712bd739c..4792c0a5b664b 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -31,7 +31,6 @@ import scala.reflect.{classTag, ClassTag}
import scala.util.control.NonFatal
import com.google.common.collect.MapMaker
-import org.apache.commons.lang3.SerializationUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
@@ -346,7 +345,7 @@ class SparkContext(config: SparkConf) extends Logging {
override protected def childValue(parent: Properties): Properties = {
// Note: make a clone such that changes in the parent properties aren't reflected in
// those of the children threads, which has confusing semantics (SPARK-10563).
- SerializationUtils.clone(parent)
+ Utils.cloneProperties(parent)
}
override protected def initialValue(): Properties = new Properties()
}
@@ -367,7 +366,7 @@ class SparkContext(config: SparkConf) extends Logging {
* @param logLevel The desired log level as a string.
* Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
*/
- def setLogLevel(logLevel: String) {
+ def setLogLevel(logLevel: String): Unit = {
// let's allow lowercase or mixed case too
val upperCased = logLevel.toUpperCase(Locale.ROOT)
require(SparkContext.VALID_LOG_LEVELS.contains(upperCased),
@@ -662,7 +661,7 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] def getLocalProperties: Properties = localProperties.get()
- private[spark] def setLocalProperties(props: Properties) {
+ private[spark] def setLocalProperties(props: Properties): Unit = {
localProperties.set(props)
}
@@ -677,7 +676,7 @@ class SparkContext(config: SparkConf) extends Logging {
* implementation of thread pools have worker threads spawn other worker threads.
* As a result, local properties may propagate unpredictably.
*/
- def setLocalProperty(key: String, value: String) {
+ def setLocalProperty(key: String, value: String): Unit = {
if (value == null) {
localProperties.get.remove(key)
} else {
@@ -693,7 +692,7 @@ class SparkContext(config: SparkConf) extends Logging {
Option(localProperties.get).map(_.getProperty(key)).orNull
/** Set a human readable description of the current job. */
- def setJobDescription(value: String) {
+ def setJobDescription(value: String): Unit = {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
@@ -721,7 +720,8 @@ class SparkContext(config: SparkConf) extends Logging {
* are actually stopped in a timely manner, but is off by default due to HDFS-1208, where HDFS
* may respond to Thread.interrupt() by marking nodes as dead.
*/
- def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean = false) {
+ def setJobGroup(groupId: String,
+ description: String, interruptOnCancel: Boolean = false): Unit = {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
// Note: Specifying interruptOnCancel in setJobGroup (rather than cancelJobGroup) avoids
@@ -732,7 +732,7 @@ class SparkContext(config: SparkConf) extends Logging {
}
/** Clear the current thread's job group ID and its description. */
- def clearJobGroup() {
+ def clearJobGroup(): Unit = {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null)
@@ -1560,7 +1560,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Register a listener to receive up-calls from events that happen during execution.
*/
@DeveloperApi
- def addSparkListener(listener: SparkListenerInterface) {
+ def addSparkListener(listener: SparkListenerInterface): Unit = {
listenerBus.addToSharedQueue(listener)
}
@@ -1789,14 +1789,14 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Register an RDD to be persisted in memory and/or disk storage
*/
- private[spark] def persistRDD(rdd: RDD[_]) {
+ private[spark] def persistRDD(rdd: RDD[_]): Unit = {
persistentRdds(rdd.id) = rdd
}
/**
* Unpersist an RDD from memory and/or disk storage
*/
- private[spark] def unpersistRDD(rddId: Int, blocking: Boolean) {
+ private[spark] def unpersistRDD(rddId: Int, blocking: Boolean): Unit = {
env.blockManager.master.removeRdd(rddId, blocking)
persistentRdds.remove(rddId)
listenerBus.post(SparkListenerUnpersistRDD(rddId))
@@ -1812,7 +1812,7 @@ class SparkContext(config: SparkConf) extends Logging {
*
* @note A path can be added only once. Subsequent additions of the same path are ignored.
*/
- def addJar(path: String) {
+ def addJar(path: String): Unit = {
def addLocalJarFile(file: File): String = {
try {
if (!file.exists()) {
@@ -2019,7 +2019,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Set the thread-local property for overriding the call sites
* of actions and RDDs.
*/
- def setCallSite(shortCallSite: String) {
+ def setCallSite(shortCallSite: String): Unit = {
setLocalProperty(CallSite.SHORT_FORM, shortCallSite)
}
@@ -2027,7 +2027,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Set the thread-local property for overriding the call sites
* of actions and RDDs.
*/
- private[spark] def setCallSite(callSite: CallSite) {
+ private[spark] def setCallSite(callSite: CallSite): Unit = {
setLocalProperty(CallSite.SHORT_FORM, callSite.shortForm)
setLocalProperty(CallSite.LONG_FORM, callSite.longForm)
}
@@ -2036,7 +2036,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Clear the thread-local property for overriding the call sites
* of actions and RDDs.
*/
- def clearCallSite() {
+ def clearCallSite(): Unit = {
setLocalProperty(CallSite.SHORT_FORM, null)
setLocalProperty(CallSite.LONG_FORM, null)
}
@@ -2156,8 +2156,7 @@ class SparkContext(config: SparkConf) extends Logging {
def runJob[T, U: ClassTag](
rdd: RDD[T],
processPartition: (TaskContext, Iterator[T]) => U,
- resultHandler: (Int, U) => Unit)
- {
+ resultHandler: (Int, U) => Unit): Unit = {
runJob[T, U](rdd, processPartition, 0 until rdd.partitions.length, resultHandler)
}
@@ -2171,8 +2170,7 @@ class SparkContext(config: SparkConf) extends Logging {
def runJob[T, U: ClassTag](
rdd: RDD[T],
processPartition: Iterator[T] => U,
- resultHandler: (Int, U) => Unit)
- {
+ resultHandler: (Int, U) => Unit): Unit = {
val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
runJob[T, U](rdd, processFunc, 0 until rdd.partitions.length, resultHandler)
}
@@ -2257,13 +2255,13 @@ class SparkContext(config: SparkConf) extends Logging {
* Cancel active jobs for the specified group. See `org.apache.spark.SparkContext.setJobGroup`
* for more information.
*/
- def cancelJobGroup(groupId: String) {
+ def cancelJobGroup(groupId: String): Unit = {
assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
- def cancelAllJobs() {
+ def cancelAllJobs(): Unit = {
assertNotStopped()
dagScheduler.cancelAllJobs()
}
@@ -2351,7 +2349,7 @@ class SparkContext(config: SparkConf) extends Logging {
* @param directory path to the directory where checkpoint files will be stored
* (must be HDFS path if running in cluster)
*/
- def setCheckpointDir(directory: String) {
+ def setCheckpointDir(directory: String): Unit = {
// If we are running on a cluster, log a warning if the directory is local.
// Otherwise, the driver may attempt to reconstruct the checkpointed RDD from
@@ -2423,7 +2421,7 @@ class SparkContext(config: SparkConf) extends Logging {
}
/** Post the application start event */
- private def postApplicationStart() {
+ private def postApplicationStart(): Unit = {
// Note: this code assumes that the task scheduler has been initialized and has contacted
// the cluster manager to get an application ID (in case the cluster manager provides one).
listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId),
@@ -2433,12 +2431,12 @@ class SparkContext(config: SparkConf) extends Logging {
}
/** Post the application end event */
- private def postApplicationEnd() {
+ private def postApplicationEnd(): Unit = {
listenerBus.post(SparkListenerApplicationEnd(System.currentTimeMillis))
}
/** Post the environment update event once the task scheduler is ready */
- private def postEnvironmentUpdate() {
+ private def postEnvironmentUpdate(): Unit = {
if (taskScheduler != null) {
val schedulingMode = getSchedulingMode.toString
val addedJarPaths = addedJars.keys.toSeq
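
One small but behaviour-relevant change above replaces SerializationUtils.clone(parent) with Utils.cloneProperties(parent) when child threads inherit local properties. A hedged sketch of what such a helper amounts to, copying entries instead of round-tripping through Java serialization; this is an assumption about its shape, not Spark's actual implementation.

import java.util.Properties

object ClonePropertiesSketch {
  def cloneProperties(props: Properties): Properties = {
    val copied = new Properties()
    if (props != null) {
      val names = props.stringPropertyNames().iterator()
      while (names.hasNext) {
        val key = names.next()
        copied.setProperty(key, props.getProperty(key))
      }
    }
    copied
  }
}
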
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 419f0ab065150..78ac00909ea1a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -70,7 +70,7 @@ class SparkEnv (
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
- private[spark] var isStopped = false
+ @volatile private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
// A general, soft-reference map for metadata needed during HadoopRDD split computation
@@ -79,7 +79,7 @@ class SparkEnv (
private[spark] var driverTmpDir: Option[String] = None
- private[spark] def stop() {
+ private[spark] def stop(): Unit = {
if (!isStopped) {
isStopped = true
@@ -119,7 +119,8 @@ class SparkEnv (
}
private[spark]
- def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+ def destroyPythonWorker(pythonExec: String,
+ envVars: Map[String, String], worker: Socket): Unit = {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.get(key).foreach(_.stopWorker(worker))
@@ -127,7 +128,8 @@ class SparkEnv (
}
private[spark]
- def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+ def releasePythonWorker(pythonExec: String,
+ envVars: Map[String, String], worker: Socket): Unit = {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.get(key).foreach(_.releaseWorker(worker))
@@ -141,7 +143,7 @@ object SparkEnv extends Logging {
private[spark] val driverSystemName = "sparkDriver"
private[spark] val executorSystemName = "sparkExecutor"
- def set(e: SparkEnv) {
+ def set(e: SparkEnv): Unit = {
env = e
}
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 19f71a1dec296..b13028f868072 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -83,14 +83,15 @@ case object Resubmitted extends TaskFailedReason {
case class FetchFailed(
bmAddress: BlockManagerId, // Note that bmAddress can be null
shuffleId: Int,
- mapId: Int,
+ mapId: Long,
+ mapIndex: Int,
reduceId: Int,
message: String)
extends TaskFailedReason {
override def toErrorString: String = {
val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString
- s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " +
- s"message=\n$message\n)"
+ s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapIndex=$mapIndex, " +
+ s"mapId=$mapId, reduceId=$reduceId, message=\n$message\n)"
}
/**
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 41ae3ae3b758a..b8c094dbea961 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -42,7 +42,6 @@ import org.json4s.JsonAST.JValue
import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.internal.config._
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
@@ -235,8 +234,10 @@ private[spark] object TestUtils {
val sslCtx = SSLContext.getInstance("SSL")
val trustManager = new X509TrustManager {
override def getAcceptedIssuers(): Array[X509Certificate] = null
- override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {}
- override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {}
+ override def checkClientTrusted(x509Certificates: Array[X509Certificate],
+ s: String): Unit = {}
+ override def checkServerTrusted(x509Certificates: Array[X509Certificate],
+ s: String): Unit = {}
}
val verifier = new HostnameVerifier() {
override def verify(hostname: String, session: SSLSession): Boolean = true
@@ -264,7 +265,7 @@ private[spark] object TestUtils {
try {
body(listener)
} finally {
- sc.listenerBus.waitUntilEmpty(TimeUnit.SECONDS.toMillis(10))
+ sc.listenerBus.waitUntilEmpty()
sc.listenerBus.removeListener(listener)
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 317f3c51d0154..aa01374a2f2e8 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -791,7 +791,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F],
- conf: JobConf) {
+ conf: JobConf): Unit = {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
@@ -800,7 +800,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
path: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[F]) {
+ outputFormatClass: Class[F]): Unit = {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
@@ -810,7 +810,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F],
- codec: Class[_ <: CompressionCodec]) {
+ codec: Class[_ <: CompressionCodec]): Unit = {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
}
@@ -820,7 +820,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F],
- conf: Configuration) {
+ conf: Configuration): Unit = {
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
@@ -828,7 +828,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Output the RDD to any Hadoop-supported storage system, using
* a Configuration object for that storage system.
*/
- def saveAsNewAPIHadoopDataset(conf: Configuration) {
+ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = {
rdd.saveAsNewAPIHadoopDataset(conf)
}
@@ -837,7 +837,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
path: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[F]) {
+ outputFormatClass: Class[F]): Unit = {
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
@@ -847,7 +847,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* (e.g. a table name to write to) in the same way as it would be configured for a Hadoop
* MapReduce job.
*/
- def saveAsHadoopDataset(conf: JobConf) {
+ def saveAsHadoopDataset(conf: JobConf): Unit = {
rdd.saveAsHadoopDataset(conf)
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 5ba821935ac69..1ca5262742665 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -347,7 +347,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Applies a function f to all elements of this RDD.
*/
- def foreach(f: VoidFunction[T]) {
+ def foreach(f: VoidFunction[T]): Unit = {
rdd.foreach(x => f.call(x))
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 330c2f6e6117e..149def29b8fbd 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -546,7 +546,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable {
def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)(fakeClassTag)
/** Shut down the SparkContext. */
- def stop() {
+ def stop(): Unit = {
sc.stop()
}
@@ -567,7 +567,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable {
*
* @note A path can be added only once. Subsequent additions of the same path are ignored.
*/
- def addFile(path: String) {
+ def addFile(path: String): Unit = {
sc.addFile(path)
}
@@ -593,7 +593,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable {
*
* @note A path can be added only once. Subsequent additions of the same path are ignored.
*/
- def addJar(path: String) {
+ def addJar(path: String): Unit = {
sc.addJar(path)
}
@@ -609,9 +609,9 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable {
/**
* Set the directory under which RDDs are going to be checkpointed. The directory must
- * be a HDFS path if running on a cluster.
+ * be an HDFS path if running on a cluster.
*/
- def setCheckpointDir(dir: String) {
+ def setCheckpointDir(dir: String): Unit = {
sc.setCheckpointDir(dir)
}
@@ -631,14 +631,14 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable {
/**
* Pass-through to SparkContext.setCallSite. For API support only.
*/
- def setCallSite(site: String) {
+ def setCallSite(site: String): Unit = {
sc.setCallSite(site)
}
/**
* Pass-through to SparkContext.setCallSite. For API support only.
*/
- def clearCallSite() {
+ def clearCallSite(): Unit = {
sc.clearCallSite()
}
@@ -669,7 +669,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable {
* @param logLevel The desired log level as a string.
* Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
*/
- def setLogLevel(logLevel: String) {
+ def setLogLevel(logLevel: String): Unit = {
sc.setLogLevel(logLevel)
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index fd96052f95d3f..e9c77f4086d0d 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -81,7 +81,7 @@ private[spark] object JavaUtils {
}
}
- override def remove() {
+ override def remove(): Unit = {
prev match {
case Some(k) =>
underlying match {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 4d76ff76e6752..6dc1721f56adf 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
@@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging {
* data collected from this job, the secret for authentication, and a socket auth
* server object that can be used to join the JVM serving thread in Python.
*/
- def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
+ def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
val handleFunc = (sock: Socket) => {
val out = new DataOutputStream(sock.getOutputStream)
val in = new DataInputStream(sock.getInputStream)
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
- rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+ var result: Array[Any] = null
+ rdd.sparkContext.submitJob(
+ rdd,
+ (iter: Iterator[Any]) => iter.toArray,
+ Seq(i), // The partition we are evaluating
+ (_, res: Array[Any]) => result = res,
+ result)
}
+ val prefetchIter = collectPartitionIter.buffered
// Write data until iteration is complete, client stops iteration, or error occurs
var complete = false
@@ -196,10 +204,15 @@ private[spark] object PythonRDD extends Logging {
// Read request for data, value of zero will stop iteration or non-zero to continue
if (in.readInt() == 0) {
complete = true
- } else if (collectPartitionIter.hasNext) {
+ } else if (prefetchIter.hasNext) {
// Client requested more data, attempt to collect the next partition
- val partitionArray = collectPartitionIter.next()
+ val partitionFuture = prefetchIter.next()
+ // Cause the next job to be submitted if prefetchPartitions is enabled.
+ if (prefetchPartitions) {
+ prefetchIter.headOption
+ }
+ val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf)
// Send response there is a partition to read
out.writeInt(1)
@@ -245,7 +258,7 @@ private[spark] object PythonRDD extends Logging {
new PythonBroadcast(path)
}
- def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
+ def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Unit = {
def write(obj: Any): Unit = obj match {
case null =>
@@ -431,7 +444,7 @@ private[spark] object PythonRDD extends Logging {
}
}
- def writeUTF(str: String, dataOut: DataOutputStream) {
+ def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
val bytes = str.getBytes(StandardCharsets.UTF_8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
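To make the prefetching change in toLocalIteratorAndServe easier to follow, below is a small stand-alone sketch of the same pattern, assuming an active SparkContext named sc (this is illustrative code, not taken from the patch). Each iterator element lazily submits one job, and peeking at the buffered iterator with headOption kicks off the next partition's job while the current one is being consumed.

    import scala.concurrent.Await
    import scala.concurrent.duration.Duration

    val rdd = sc.parallelize(1 to 100, numSlices = 4)

    // One lazily-submitted job per partition.
    val jobs = rdd.partitions.indices.iterator.map { i =>
      var collected: Array[Int] = null
      sc.submitJob(
        rdd,
        (it: Iterator[Int]) => it.toArray,
        Seq(i),
        (_: Int, res: Array[Int]) => collected = res,
        collected)
    }.buffered

    while (jobs.hasNext) {
      val current = jobs.next()
      jobs.headOption   // prefetch: submits the next partition's job early
      val partition = Await.result(current, Duration.Inf)
      println(partition.mkString(", "))
    }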
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index d2a10df7acbd3..dbbd841d0077a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -48,6 +48,7 @@ private[spark] object PythonEvalType {
val SQL_WINDOW_AGG_PANDAS_UDF = 203
val SQL_SCALAR_PANDAS_ITER_UDF = 204
val SQL_MAP_PANDAS_ITER_UDF = 205
+ val SQL_COGROUPED_MAP_PANDAS_UDF = 206
def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -58,6 +59,7 @@ private[spark] object PythonEvalType {
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
+ case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
}
}
@@ -192,7 +194,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
def exception: Option[Throwable] = Option(_exception)
/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
- def shutdownOnTaskCompletion() {
+ def shutdownOnTaskCompletion(): Unit = {
assert(context.isCompleted)
this.interrupt()
}
@@ -410,7 +412,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
}
- def writeUTF(str: String, dataOut: DataOutputStream) {
+ def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
val bytes = str.getBytes(UTF_8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
@@ -529,7 +531,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
setDaemon(true)
- override def run() {
+ override def run(): Unit = {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
while (!context.isInterrupted && !context.isCompleted) {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 6c37844a088ce..1926a5268227c 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -189,7 +189,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
null
}
- private def startDaemon() {
+ private def startDaemon(): Unit = {
self.synchronized {
// Is it already running?
if (daemon != null) {
@@ -271,7 +271,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
/**
* Redirect the given streams to our stderr in separate threads.
*/
- private def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream) {
+ private def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream): Unit = {
try {
new RedirectThread(stdout, System.err, "stdout reader for " + pythonExec).start()
new RedirectThread(stderr, System.err, "stderr reader for " + pythonExec).start()
@@ -288,7 +288,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
setDaemon(true)
- override def run() {
+ override def run(): Unit = {
while (true) {
self.synchronized {
if (IDLE_WORKER_TIMEOUT_NS < System.nanoTime() - lastActivityNs) {
@@ -301,7 +301,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
- private def cleanupIdleWorkers() {
+ private def cleanupIdleWorkers(): Unit = {
while (idleWorkers.nonEmpty) {
val worker = idleWorkers.dequeue()
try {
@@ -314,7 +314,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
- private def stopDaemon() {
+ private def stopDaemon(): Unit = {
self.synchronized {
if (useDaemon) {
cleanupIdleWorkers()
@@ -332,11 +332,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
- def stop() {
+ def stop(): Unit = {
stopDaemon()
}
- def stopWorker(worker: Socket) {
+ def stopWorker(worker: Socket): Unit = {
self.synchronized {
if (useDaemon) {
if (daemon != null) {
@@ -355,7 +355,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
worker.close()
}
- def releaseWorker(worker: Socket) {
+ def releaseWorker(worker: Socket): Unit = {
if (useDaemon) {
self.synchronized {
lastActivityNs = System.nanoTime()
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
index 86965dbc2e778..4e790b364e1d2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -37,11 +37,11 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten
def this() = this("", 0, 0.0)
def getStr: String = str
- def setStr(str: String) { this.str = str }
+ def setStr(str: String): Unit = { this.str = str }
def getInt: Int = int
- def setInt(int: Int) { this.int = int }
+ def setInt(int: Int): Unit = { this.int = int }
def getDouble: Double = double
- def setDouble(double: Double) { this.double = double }
+ def setDouble(double: Double): Unit = { this.double = double }
def write(out: DataOutput): Unit = {
out.writeUTF(str)
@@ -106,13 +106,13 @@ private[python] class WritableToDoubleArrayConverter extends Converter[Any, Arra
*/
object WriteInputFormatTestDataGenerator {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val path = args(0)
val sc = new JavaSparkContext("local[4]", "test-writables")
generateData(path, sc)
}
- def generateData(path: String, jsc: JavaSparkContext) {
+ def generateData(path: String, jsc: JavaSparkContext): Unit = {
val sc = jsc.sc
val basePath = s"$path/sftestdata/"
diff --git a/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala b/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
index f96c5215cf0af..d8f9d1f1729b7 100644
--- a/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
@@ -230,7 +230,7 @@ private[spark] class BufferedStreamThread(
errBufferSize: Int) extends Thread(name) with Logging {
val lines = new Array[String](errBufferSize)
var lineIdx = 0
- override def run() {
+ override def run(): Unit = {
for (line <- Source.fromInputStream(in).getLines) {
synchronized {
lines(lineIdx) = line
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 0e81ad198db67..9ef6c7c5906a2 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -74,7 +74,7 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo
* Asynchronously delete cached copies of this broadcast on the executors.
* If the broadcast is used after this is called, it will need to be re-sent to each executor.
*/
- def unpersist() {
+ def unpersist(): Unit = {
unpersist(blocking = false)
}
@@ -83,7 +83,7 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo
* this is called, it will need to be re-sent to each executor.
* @param blocking Whether to block until unpersisting has completed
*/
- def unpersist(blocking: Boolean) {
+ def unpersist(blocking: Boolean): Unit = {
assertValid()
doUnpersist(blocking)
}
@@ -93,7 +93,7 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo
* Destroy all data and metadata related to this broadcast variable. Use this with caution;
* once a broadcast variable has been destroyed, it cannot be used again.
*/
- def destroy() {
+ def destroy(): Unit = {
destroy(blocking = false)
}
@@ -102,7 +102,7 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo
* once a broadcast variable has been destroyed, it cannot be used again.
* @param blocking Whether to block until destroy has completed
*/
- private[spark] def destroy(blocking: Boolean) {
+ private[spark] def destroy(blocking: Boolean): Unit = {
assertValid()
_isValid = false
_destroySite = Utils.getCallSite().shortForm
@@ -128,17 +128,17 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo
* Actually unpersist the broadcasted value on the executors. Concrete implementations of
* Broadcast class must define their own logic to unpersist their own data.
*/
- protected def doUnpersist(blocking: Boolean)
+ protected def doUnpersist(blocking: Boolean): Unit
/**
* Actually destroy all data and metadata related to this broadcast variable.
* Implementation of Broadcast class must define their own logic to destroy their own
* state.
*/
- protected def doDestroy(blocking: Boolean)
+ protected def doDestroy(blocking: Boolean): Unit
/** Check if this broadcast is valid. If not valid, exception is thrown. */
- protected def assertValid() {
+ protected def assertValid(): Unit = {
if (!_isValid) {
throw new SparkException(
"Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite))
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index 9fa47451c1831..c93cadf1ab3e8 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -40,7 +40,7 @@ private[spark] class BroadcastManager(
initialize()
// Called by SparkContext or Executor before using Broadcast
- private def initialize() {
+ private def initialize(): Unit = {
synchronized {
if (!initialized) {
broadcastFactory = new TorrentBroadcastFactory
@@ -50,7 +50,7 @@ private[spark] class BroadcastManager(
}
}
- def stop() {
+ def stop(): Unit = {
broadcastFactory.stop()
}
@@ -77,7 +77,7 @@ private[spark] class BroadcastManager(
broadcastFactory.newBroadcast[T](value_, isLocal, bid)
}
- def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
}
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 1379314ba1b53..77fbbc08c2103 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -73,7 +73,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
/** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
@transient private var blockSize: Int = _
- private def setConf(conf: SparkConf) {
+ private def setConf(conf: SparkConf): Unit = {
compressionCodec = if (conf.get(config.BROADCAST_COMPRESS)) {
Some(CompressionCodec.createCodec(conf))
} else {
@@ -196,7 +196,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
/**
* Remove all persisted state associated with this Torrent broadcast on the executors.
*/
- override protected def doUnpersist(blocking: Boolean) {
+ override protected def doUnpersist(blocking: Boolean): Unit = {
TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
}
@@ -204,7 +204,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
* Remove all persisted state associated with this Torrent broadcast on the executors
* and driver.
*/
- override protected def doDestroy(blocking: Boolean) {
+ override protected def doDestroy(blocking: Boolean): Unit = {
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index b11f9ba171b84..65fb5186afae1 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -28,20 +28,21 @@ import org.apache.spark.{SecurityManager, SparkConf}
*/
private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
- override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
+ override def initialize(isDriver: Boolean, conf: SparkConf,
+ securityMgr: SecurityManager): Unit = { }
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
new TorrentBroadcast[T](value_, id)
}
- override def stop() { }
+ override def stop(): Unit = { }
/**
* Remove all persisted state associated with the torrent broadcast with the given ID.
* @param removeFromDriver Whether to remove state from the driver.
* @param blocking Whether to block until unbroadcasted
*/
- override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 648a8b1c763db..7022b986ea025 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -219,7 +219,7 @@ private class ClientEndpoint(
* Executable utility for starting and terminating drivers inside of a standalone cluster.
*/
object Client {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
// scalastyle:off println
if (!sys.props.contains("SPARK_SUBMIT")) {
println("WARNING: This client is deprecated and will be removed in a future version of Spark")
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index a86ee66fb72b9..9d6bbf91168da 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -100,7 +100,7 @@ private[deploy] class ClientArguments(args: Array[String]) {
/**
* Print usage and exit JVM with the given exit code.
*/
- private def printUsageAndExit(exitCode: Int) {
+ private def printUsageAndExit(exitCode: Int): Unit = {
// TODO: It wouldn't be too hard to allow users to submit their app and dependency jars
// separately similar to in the YARN client.
val usage =
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index 64277e8de2a4d..ebfff89308886 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -87,14 +87,14 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
}
/** Starts the external shuffle service if the user has configured us to. */
- def startIfEnabled() {
+ def startIfEnabled(): Unit = {
if (enabled) {
start()
}
}
/** Start the external shuffle service */
- def start() {
+ def start(): Unit = {
require(server == null, "Shuffle server already started")
val authEnabled = securityManager.isAuthenticationEnabled()
logInfo(s"Starting shuffle service on port $port (auth enabled = $authEnabled)")
@@ -125,7 +125,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
blockHandler.executorRemoved(executorId, appId)
}
- def stop() {
+ def stop(): Unit = {
if (server != null) {
server.close()
server = null
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index 99f841234005e..6ff68b694f8f3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -78,7 +78,7 @@ private object FaultToleranceTest extends App with Logging {
System.setProperty(config.DRIVER_HOST_ADDRESS.key, "172.17.42.1") // default docker host ip
- private def afterEach() {
+ private def afterEach(): Unit = {
if (sc != null) {
sc.stop()
sc = null
@@ -180,7 +180,7 @@ private object FaultToleranceTest extends App with Logging {
}
}
- private def test(name: String)(fn: => Unit) {
+ private def test(name: String)(fn: => Unit): Unit = {
try {
fn
numPassed += 1
@@ -198,12 +198,12 @@ private object FaultToleranceTest extends App with Logging {
afterEach()
}
- private def addMasters(num: Int) {
+ private def addMasters(num: Int): Unit = {
logInfo(s">>>>> ADD MASTERS $num <<<<<")
(1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) }
}
- private def addWorkers(num: Int) {
+ private def addWorkers(num: Int): Unit = {
logInfo(s">>>>> ADD WORKERS $num <<<<<")
val masterUrls = getMasterUrls(masters)
(1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) }
@@ -239,7 +239,7 @@ private object FaultToleranceTest extends App with Logging {
private def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis)
- private def terminateCluster() {
+ private def terminateCluster(): Unit = {
logInfo(">>>>> TERMINATE CLUSTER <<<<<")
masters.foreach(_.kill())
workers.foreach(_.kill())
@@ -326,7 +326,7 @@ private object FaultToleranceTest extends App with Logging {
}
}
- private def assertTrue(bool: Boolean, message: String = "") {
+ private def assertTrue(bool: Boolean, message: String = ""): Unit = {
if (!bool) {
throw new IllegalStateException("Assertion failed: " + message)
}
@@ -346,7 +346,7 @@ private class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile
logDebug("Created master: " + this)
- def readState() {
+ def readState(): Unit = {
try {
val masterStream = new InputStreamReader(
new URL("http://%s:8080/json".format(ip)).openStream, StandardCharsets.UTF_8)
@@ -372,7 +372,7 @@ private class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile
}
}
- def kill() { Docker.kill(dockerId) }
+ def kill(): Unit = { Docker.kill(dockerId) }
override def toString: String =
"[ip=%s, id=%s, logFile=%s, state=%s]".
@@ -386,7 +386,7 @@ private class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile
logDebug("Created worker: " + this)
- def kill() { Docker.kill(dockerId) }
+ def kill(): Unit = { Docker.kill(dockerId) }
override def toString: String =
"[ip=%s, id=%s, logFile=%s]".format(ip, dockerId, logFile.getAbsolutePath)
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index f1b58eb33a1b7..fc849d7f4372f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -72,7 +72,7 @@ class LocalSparkCluster(
masters
}
- def stop() {
+ def stop(): Unit = {
logInfo("Shutting down local Spark cluster.")
// Stop the workers before the master so they don't get upset that it disconnected
workerRpcEnvs.foreach(_.shutdown())
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index 8055a6270dac8..0c9d34986af63 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.{RedirectThread, Utils}
* subprocess and then has it connect back to the JVM to access system properties, etc.
*/
object PythonRunner {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val pythonFile = args(0)
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index 60ba0470a628a..b32f9ea3b4747 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -73,7 +73,7 @@ object RRunner {
@volatile var sparkRBackendSecret: String = null
val initialized = new Semaphore(0)
val sparkRBackendThread = new Thread("SparkR backend") {
- override def run() {
+ override def run(): Unit = {
val (port, authHelper) = sparkRBackend.init()
sparkRBackendPort = port
sparkRBackendSecret = authHelper.secret
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala
index 8118c01eb712f..b89ae1b35e693 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala
@@ -45,7 +45,7 @@ private[spark] object SparkCuratorUtil extends Logging {
zk
}
- def mkdir(zk: CuratorFramework, path: String) {
+ def mkdir(zk: CuratorFramework, path: String): Unit = {
if (zk.checkExists().forPath(path) == null) {
try {
zk.create().creatingParentsIfNeeded().forPath(path)
@@ -57,7 +57,7 @@ private[spark] object SparkCuratorUtil extends Logging {
}
}
- def deleteRecursive(zk: CuratorFramework, path: String) {
+ def deleteRecursive(zk: CuratorFramework, path: String): Unit = {
if (zk.checkExists().forPath(path) != null) {
for (child <- zk.getChildren.forPath(path).asScala) {
zk.delete().forPath(path + "/" + child)
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 11420bb985520..1180501e8c738 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -57,7 +57,7 @@ private[spark] class SparkHadoopUtil extends Logging {
* you need to look https://issues.apache.org/jira/browse/HDFS-3545 and possibly
* do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems
*/
- def runAsSparkUser(func: () => Unit) {
+ def runAsSparkUser(func: () => Unit): Unit = {
createSparkUser().doAs(new PrivilegedExceptionAction[Unit] {
def run: Unit = func()
})
@@ -71,7 +71,7 @@ private[spark] class SparkHadoopUtil extends Logging {
ugi
}
- def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
+ def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation): Unit = {
dest.addCredentials(source.getCredentials())
}
@@ -79,8 +79,10 @@ private[spark] class SparkHadoopUtil extends Logging {
* Appends S3-specific, spark.hadoop.*, and spark.buffer.size configurations to a Hadoop
* configuration.
*/
- def appendS3AndSparkHadoopConfigurations(conf: SparkConf, hadoopConf: Configuration): Unit = {
- SparkHadoopUtil.appendS3AndSparkHadoopConfigurations(conf, hadoopConf)
+ def appendS3AndSparkHadoopHiveConfigurations(
+ conf: SparkConf,
+ hadoopConf: Configuration): Unit = {
+ SparkHadoopUtil.appendS3AndSparkHadoopHiveConfigurations(conf, hadoopConf)
}
/**
@@ -103,6 +105,15 @@ private[spark] class SparkHadoopUtil extends Logging {
}
}
+ def appendSparkHiveConfigs(
+ srcMap: Map[String, String],
+ destMap: HashMap[String, String]): Unit = {
+ // Copy any "spark.hive.foo=bar" system properties into destMap as "hive.foo=bar"
+ for ((key, value) <- srcMap if key.startsWith("spark.hive.")) {
+ destMap.put(key.substring("spark.".length), value)
+ }
+ }
+
/**
* Return an appropriate (subclass) of Configuration. Creating config can initialize some Hadoop
* subsystems.
@@ -140,7 +151,7 @@ private[spark] class SparkHadoopUtil extends Logging {
* Add or overwrite current user's credentials with serialized delegation tokens,
* also confirms correct hadoop configuration is set.
*/
- private[spark] def addDelegationTokens(tokens: Array[Byte], sparkConf: SparkConf) {
+ private[spark] def addDelegationTokens(tokens: Array[Byte], sparkConf: SparkConf): Unit = {
UserGroupInformation.setConfiguration(newConfiguration(sparkConf))
val creds = deserialize(tokens)
logInfo("Updating delegation tokens for current user.")
@@ -413,11 +424,11 @@ private[spark] object SparkHadoopUtil {
*/
private[spark] def newConfiguration(conf: SparkConf): Configuration = {
val hadoopConf = new Configuration()
- appendS3AndSparkHadoopConfigurations(conf, hadoopConf)
+ appendS3AndSparkHadoopHiveConfigurations(conf, hadoopConf)
hadoopConf
}
- private def appendS3AndSparkHadoopConfigurations(
+ private def appendS3AndSparkHadoopHiveConfigurations(
conf: SparkConf,
hadoopConf: Configuration): Unit = {
// Note: this null check is around more than just access to the "conf" object to maintain
@@ -440,6 +451,7 @@ private[spark] object SparkHadoopUtil {
}
}
appendSparkHadoopConfigs(conf, hadoopConf)
+ appendSparkHiveConfigs(conf, hadoopConf)
val bufferSize = conf.get(BUFFER_SIZE).toString
hadoopConf.set("io.file.buffer.size", bufferSize)
}
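The net effect of the new appendSparkHiveConfigs hook is easiest to see from the driver side. The sketch below is my own illustration rather than code from the patch; SparkHadoopUtil is package-private, so this is the kind of assertion Spark's own tests would make. The metastore URI is an assumed example value.

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.SparkHadoopUtil

    val conf = new SparkConf()
      .set("spark.hive.metastore.uris", "thrift://metastore:9083")  // hypothetical value

    // Any "spark.hive.foo" entry should surface as "hive.foo" in the Hadoop configuration.
    val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
    assert(hadoopConf.get("hive.metastore.uris") == "thrift://metastore:9083")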
@@ -452,37 +464,48 @@ private[spark] object SparkHadoopUtil {
}
}
+ private def appendSparkHiveConfigs(conf: SparkConf, hadoopConf: Configuration): Unit = {
+ // Copy any "spark.hive.foo=bar" spark properties into conf as "hive.foo=bar"
+ for ((key, value) <- conf.getAll if key.startsWith("spark.hive.")) {
+ hadoopConf.set(key.substring("spark.".length), value)
+ }
+ }
+
// scalastyle:off line.size.limit
/**
- * Create a path that uses replication instead of erasure coding (ec), regardless of the default
- * configuration in hdfs for the given path. This can be helpful as hdfs ec doesn't support
- * hflush(), hsync(), or append()
+ * Create a file on the given file system, optionally making sure erasure coding is disabled.
+ *
+ * Disabling EC can be helpful as HDFS EC doesn't support hflush(), hsync(), or append().
* https://hadoop.apache.org/docs/r3.0.0/hadoop-project-dist/hadoop-hdfs/HDFSErasureCoding.html#Limitations
*/
// scalastyle:on line.size.limit
- def createNonECFile(fs: FileSystem, path: Path): FSDataOutputStream = {
- try {
- // Use reflection as this uses APIs only available in Hadoop 3
- val builderMethod = fs.getClass().getMethod("createFile", classOf[Path])
- // the builder api does not resolve relative paths, nor does it create parent dirs, while
- // the old api does.
- if (!fs.mkdirs(path.getParent())) {
- throw new IOException(s"Failed to create parents of $path")
+ def createFile(fs: FileSystem, path: Path, allowEC: Boolean): FSDataOutputStream = {
+ if (allowEC) {
+ fs.create(path)
+ } else {
+ try {
+ // Use reflection as this uses APIs only available in Hadoop 3
+ val builderMethod = fs.getClass().getMethod("createFile", classOf[Path])
+ // the builder api does not resolve relative paths, nor does it create parent dirs, while
+ // the old api does.
+ if (!fs.mkdirs(path.getParent())) {
+ throw new IOException(s"Failed to create parents of $path")
+ }
+ val qualifiedPath = fs.makeQualified(path)
+ val builder = builderMethod.invoke(fs, qualifiedPath)
+ val builderCls = builder.getClass()
+ // this may throw a NoSuchMethodException if the path is not on hdfs
+ val replicateMethod = builderCls.getMethod("replicate")
+ val buildMethod = builderCls.getMethod("build")
+ val b2 = replicateMethod.invoke(builder)
+ buildMethod.invoke(b2).asInstanceOf[FSDataOutputStream]
+ } catch {
+ case _: NoSuchMethodException =>
+ // No createFile() method, we're using an older hdfs client, which doesn't give us control
+ // over EC vs. replication. Older hdfs doesn't have EC anyway, so just create a file with
+ // old apis.
+ fs.create(path)
}
- val qualifiedPath = fs.makeQualified(path)
- val builder = builderMethod.invoke(fs, qualifiedPath)
- val builderCls = builder.getClass()
- // this may throw a NoSuchMethodException if the path is not on hdfs
- val replicateMethod = builderCls.getMethod("replicate")
- val buildMethod = builderCls.getMethod("build")
- val b2 = replicateMethod.invoke(builder)
- buildMethod.invoke(b2).asInstanceOf[FSDataOutputStream]
- } catch {
- case _: NoSuchMethodException =>
- // No createFile() method, we're using an older hdfs client, which doesn't give us control
- // over EC vs. replication. Older hdfs doesn't have EC anyway, so just create a file with
- // old apis.
- fs.create(path)
}
}
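Below is a rough sketch (mine, not from the patch) of how the renamed createFile helper behaves with the new allowEC flag: with allowEC = true it is a plain fs.create, while allowEC = false takes the reflective Hadoop 3 builder path that forces replication. The path and FileSystem here are assumptions for illustration.

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.{FileSystem, Path}
    import org.apache.spark.deploy.SparkHadoopUtil

    val fs = FileSystem.get(new Configuration())
    val logPath = new Path("/tmp/spark-events/app-0001.inprogress")  // hypothetical event log path

    // Streams that need hflush()/append() are incompatible with HDFS erasure coding,
    // so such callers pass allowEC = false to get a replicated file.
    val out = SparkHadoopUtil.createFile(fs, logPath, allowEC = false)
    out.writeUTF("event")
    out.hflush()
    out.close()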
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 12a8473b22025..b776ec8f81e06 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -229,10 +229,6 @@ private[spark] class SparkSubmit extends Logging {
// Set the cluster manager
val clusterManager: Int = args.master match {
case "yarn" => YARN
- case "yarn-client" | "yarn-cluster" =>
- logWarning(s"Master ${args.master} is deprecated since 2.0." +
- " Please use master \"yarn\" with specified deploy mode instead.")
- YARN
case m if m.startsWith("spark") => STANDALONE
case m if m.startsWith("mesos") => MESOS
case m if m.startsWith("k8s") => KUBERNETES
@@ -251,22 +247,7 @@ private[spark] class SparkSubmit extends Logging {
-1
}
- // Because the deprecated way of specifying "yarn-cluster" and "yarn-client" encapsulate both
- // the master and deploy mode, we have some logic to infer the master and deploy mode
- // from each other if only one is specified, or exit early if they are at odds.
if (clusterManager == YARN) {
- (args.master, args.deployMode) match {
- case ("yarn-cluster", null) =>
- deployMode = CLUSTER
- args.master = "yarn"
- case ("yarn-cluster", "client") =>
- error("Client deploy mode is not compatible with master \"yarn-cluster\"")
- case ("yarn-client", "cluster") =>
- error("Cluster deploy mode is not compatible with master \"yarn-client\"")
- case (_, mode) =>
- args.master = "yarn"
- }
-
// Make sure YARN is included in our build if we're trying to use it
if (!Utils.classIsLoadable(YARN_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) {
error(
@@ -1047,7 +1028,7 @@ object SparkSubmit extends CommandLineUtils with Logging {
* Return whether the given primary resource requires running R.
*/
private[deploy] def isR(res: String): Boolean = {
- res != null && res.endsWith(".R") || res == SPARKR_SHELL
+ res != null && (res.endsWith(".R") || res.endsWith(".r")) || res == SPARKR_SHELL
}
private[deploy] def isInternal(res: String): Boolean = {
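A brief note on the widened isR check above: lowercase .r scripts are now recognized as R applications. Since isR is private[deploy], the following is merely the sort of assertion a test inside that package might make, not user-facing code.

    import org.apache.spark.deploy.SparkSubmit

    assert(SparkSubmit.isR("analysis.R"))
    assert(SparkSubmit.isR("analysis.r"))   // accepted after this change
    assert(!SparkSubmit.isR("job.py"))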
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala
index 34ade4ce6f39b..8f17159228f8b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala
@@ -120,7 +120,7 @@ private[spark] class StandaloneAppClient(
*
* nthRetry means this is the nth attempt to register with master.
*/
- private def registerWithMaster(nthRetry: Int) {
+ private def registerWithMaster(nthRetry: Int): Unit = {
registerMasterFutures.set(tryRegisterAllMasters())
registrationRetryTimer.set(registrationRetryThread.schedule(new Runnable {
override def run(): Unit = {
@@ -246,14 +246,14 @@ private[spark] class StandaloneAppClient(
/**
* Notify the listener that we disconnected, if we hadn't already done so before.
*/
- def markDisconnected() {
+ def markDisconnected(): Unit = {
if (!alreadyDisconnected) {
listener.disconnected()
alreadyDisconnected = true
}
}
- def markDead(reason: String) {
+ def markDead(reason: String): Unit = {
if (!alreadyDead.get) {
listener.dead(reason)
alreadyDead.set(true)
@@ -271,12 +271,12 @@ private[spark] class StandaloneAppClient(
}
- def start() {
+ def start(): Unit = {
// Just launch an rpcEndpoint; it will call back into the listener.
endpoint.set(rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)))
}
- def stop() {
+ def stop(): Unit = {
if (endpoint.get != null) {
try {
val timeout = RpcUtils.askRpcTimeout(conf)
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala
index 8c63fa65b40fd..fb2a67c2ab103 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala
@@ -209,9 +209,8 @@ private[history] class ApplicationCache(
/**
* Register a filter for the web UI which checks for updates to the given app/attempt
- * @param ui Spark UI to attach filters to
- * @param appId application ID
- * @param attemptId attempt ID
+ * @param key cache key, consisting of the appId and attemptId
+ * @param loadedUI Spark UI to attach filters to
*/
private def registerFilter(key: CacheKey, loadedUI: LoadedAppUI): Unit = {
require(loadedUI != null)
@@ -231,7 +230,7 @@ private[history] class ApplicationCache(
/**
* An entry in the cache.
*
- * @param ui Spark UI
+ * @param loadedUI Spark UI
* @param completed Flag to indicated that the application has completed (and so
* does not need refreshing).
*/
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
index f1c06205bf04c..472b52957ed7f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
@@ -114,6 +114,12 @@ private[history] abstract class ApplicationHistoryProvider {
*/
def stop(): Unit = { }
+ /**
+ * Called when the server is starting up. Implement this function to initialize the provider and
+ * start background threads. This allows the provider to be started after it has been created.
+ */
+ def start(): Unit = { }
+
/**
* Returns configuration data to be shown in the History Server home page.
*
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 5f9b18ce01279..dce9581be2905 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -200,7 +200,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
}
}
- val initThread = initialize()
+ var initThread: Thread = null
private[history] def initialize(): Thread = {
if (!isFsInSafeMode()) {
@@ -384,6 +384,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
Map("Event log directory" -> logDir.toString) ++ safeMode
}
+ override def start(): Unit = {
+ initThread = initialize()
+ }
+
override def stop(): Unit = {
try {
if (initThread != null && initThread.isAlive()) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index 878f0cb632c5a..62cac261ae014 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -135,7 +135,7 @@ class HistoryServer(
* This starts a background thread that periodically synchronizes information displayed on
* this UI with the event logs in the provided base directory.
*/
- def initialize() {
+ def initialize(): Unit = {
attachPage(new HistoryPage(this))
attachHandler(ApiRootResource.getServletHandler(this))
@@ -149,12 +149,12 @@ class HistoryServer(
}
/** Bind to the HTTP server behind this web interface. */
- override def bind() {
+ override def bind(): Unit = {
super.bind()
}
/** Stop the server and close the file system. */
- override def stop() {
+ override def stop(): Unit = {
super.stop()
provider.stop()
}
@@ -164,7 +164,7 @@ class HistoryServer(
appId: String,
attemptId: Option[String],
ui: SparkUI,
- completed: Boolean) {
+ completed: Boolean): Unit = {
assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs")
ui.getHandlers.foreach { handler =>
serverInfo.get.addHandler(handler, ui.securityManager)
@@ -297,6 +297,7 @@ object HistoryServer extends Logging {
val server = new HistoryServer(conf, provider, securityManager, port)
server.bind()
+ provider.start()
ShutdownHookManager.addShutdownHook { () => server.stop() }
@@ -326,7 +327,7 @@ object HistoryServer extends Logging {
new SecurityManager(config)
}
- def initSecurity() {
+ def initSecurity(): Unit = {
// If we are accessing HDFS and it has security enabled (Kerberos), we have to login
// from a keytab file so that we can access HDFS beyond the kerberos ticket expiration.
// As long as it is using Hadoop rpc (hdfs://), a relogin will automatically
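The new start() hook changes the History Server bootstrap order: the provider is constructed first, the web UI binds, and only then does the provider spin up its background work. The sketch below mirrors what HistoryServer.main does after this patch; the log directory is an assumed value and FsHistoryProvider is package-private, so this is illustrative rather than a runnable user program.

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.deploy.history.{FsHistoryProvider, HistoryServer}

    val conf = new SparkConf().set("spark.history.fs.logDirectory", "/tmp/spark-events")
    val provider = new FsHistoryProvider(conf)   // no longer starts its init thread in the constructor
    val server = new HistoryServer(conf, provider, new SecurityManager(conf), 18080)

    server.bind()       // bind the web UI first
    provider.start()    // only now does the provider start its log-scanning thread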
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index dec89769c030b..01cc59e1d2e6e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -52,7 +52,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin
// This mutates the SparkConf, so all accesses to it must be made after this line
Utils.loadDefaultSparkProperties(conf, propertiesFile)
- private def printUsageAndExit(exitCode: Int) {
+ private def printUsageAndExit(exitCode: Int): Unit = {
// scalastyle:off println
System.err.println(
"""
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index 6c56807458b27..03965e6dbbf31 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -57,7 +57,7 @@ private[spark] class ApplicationInfo(
init()
}
- private def init() {
+ private def init(): Unit = {
state = ApplicationState.WAITING
executors = new mutable.HashMap[Int, ExecutorDesc]
coresGranted = 0
@@ -92,7 +92,7 @@ private[spark] class ApplicationInfo(
exec
}
- private[master] def removeExecutor(exec: ExecutorDesc) {
+ private[master] def removeExecutor(exec: ExecutorDesc): Unit = {
if (executors.contains(exec.id)) {
removedExecutors += executors(exec.id)
executors -= exec.id
@@ -115,7 +115,7 @@ private[spark] class ApplicationInfo(
private[master] def resetRetryCount() = _retryCount = 0
- private[master] def markFinished(endState: ApplicationState.Value) {
+ private[master] def markFinished(endState: ApplicationState.Value): Unit = {
state = endState
endTime = System.currentTimeMillis()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
index a8f8492561115..a598d2a1ddd76 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
@@ -33,7 +33,7 @@ private[master] class ExecutorDesc(
var state = ExecutorState.LAUNCHING
/** Copy all state (non-val) variables from the given on-the-wire ExecutorDescription. */
- def copyState(execDesc: ExecutorDescription) {
+ def copyState(execDesc: ExecutorDescription): Unit = {
state = execDesc.state
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index f2b5ea7e23ec1..ba949e2630e43 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -56,7 +56,7 @@ private[master] class FileSystemPersistenceEngine(
files.map(deserializeFromFile[T])
}
- private def serializeIntoFile(file: File, value: AnyRef) {
+ private def serializeIntoFile(file: File, value: AnyRef): Unit = {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
val fileOut = new FileOutputStream(file)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index 52e2854961eda..5bdfd18f37cd0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -27,7 +27,7 @@ import org.apache.spark.annotation.DeveloperApi
@DeveloperApi
trait LeaderElectionAgent {
val masterInstance: LeaderElectable
- def stop() {} // to avoid noops in implementations.
+ def stop(): Unit = {} // to avoid noops in implementations.
}
@DeveloperApi
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 5588dc8cff47a..8d3795cae707a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -192,7 +192,7 @@ private[deploy] class Master(
leaderElectionAgent = leaderElectionAgent_
}
- override def onStop() {
+ override def onStop(): Unit = {
masterMetricsSystem.report()
applicationMetricsSystem.report()
// prevent the CompleteRecovery message sending to restarted master
@@ -211,11 +211,11 @@ private[deploy] class Master(
leaderElectionAgent.stop()
}
- override def electedLeader() {
+ override def electedLeader(): Unit = {
self.send(ElectedLeader)
}
- override def revokedLeadership() {
+ override def revokedLeadership(): Unit = {
self.send(RevokedLeadership)
}
@@ -529,7 +529,7 @@ private[deploy] class Master(
apps.count(_.state == ApplicationState.UNKNOWN) == 0
private def beginRecovery(storedApps: Seq[ApplicationInfo], storedDrivers: Seq[DriverInfo],
- storedWorkers: Seq[WorkerInfo]) {
+ storedWorkers: Seq[WorkerInfo]): Unit = {
for (app <- storedApps) {
logInfo("Trying to recover app: " + app.id)
try {
@@ -559,7 +559,7 @@ private[deploy] class Master(
}
}
- private def completeRecovery() {
+ private def completeRecovery(): Unit = {
// Ensure "only-once" recovery semantics using a short synchronization period.
if (state != RecoveryState.RECOVERING) { return }
state = RecoveryState.COMPLETING_RECOVERY
@@ -850,7 +850,7 @@ private[deploy] class Master(
true
}
- private def removeWorker(worker: WorkerInfo, msg: String) {
+ private def removeWorker(worker: WorkerInfo, msg: String): Unit = {
logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port)
worker.setState(WorkerState.DEAD)
idToWorker -= worker.id
@@ -879,7 +879,7 @@ private[deploy] class Master(
persistenceEngine.removeWorker(worker)
}
- private def relaunchDriver(driver: DriverInfo) {
+ private def relaunchDriver(driver: DriverInfo): Unit = {
// We must setup a new driver with a new driver id here, because the original driver may
// be still running. Consider this scenario: a worker is network partitioned with master,
// the master then relaunches driver driverID1 with a driver id driverID2, then the worker
@@ -919,11 +919,11 @@ private[deploy] class Master(
waitingApps += app
}
- private def finishApplication(app: ApplicationInfo) {
+ private def finishApplication(app: ApplicationInfo): Unit = {
removeApplication(app, ApplicationState.FINISHED)
}
- def removeApplication(app: ApplicationInfo, state: ApplicationState.Value) {
+ def removeApplication(app: ApplicationInfo, state: ApplicationState.Value): Unit = {
if (apps.contains(app)) {
logInfo("Removing app " + app.id)
apps -= app
@@ -1047,7 +1047,7 @@ private[deploy] class Master(
}
/** Check for, and remove, any timed-out workers */
- private def timeOutDeadWorkers() {
+ private def timeOutDeadWorkers(): Unit = {
// Copy the workers into an array so we don't modify the hashset while iterating through it
val currentTime = System.currentTimeMillis()
val toRemove = workers.filter(_.lastHeartbeat < currentTime - workerTimeoutMs).toArray
@@ -1077,7 +1077,7 @@ private[deploy] class Master(
new DriverInfo(now, newDriverId(date), desc, date)
}
- private def launchDriver(worker: WorkerInfo, driver: DriverInfo) {
+ private def launchDriver(worker: WorkerInfo, driver: DriverInfo): Unit = {
logInfo("Launching driver " + driver.id + " on worker " + worker.id)
worker.addDriver(driver)
driver.worker = Some(worker)
@@ -1088,7 +1088,7 @@ private[deploy] class Master(
private def removeDriver(
driverId: String,
finalState: DriverState,
- exception: Option[Exception]) {
+ exception: Option[Exception]): Unit = {
drivers.find(d => d.id == driverId) match {
case Some(driver) =>
logInfo(s"Removing driver: $driverId")
@@ -1113,7 +1113,7 @@ private[deploy] object Master extends Logging {
val SYSTEM_NAME = "sparkMaster"
val ENDPOINT_NAME = "Master"
- def main(argStrings: Array[String]) {
+ def main(argStrings: Array[String]): Unit = {
Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler(
exitOnUncaughtException = false))
Utils.initDaemon(log)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
index cd31bbdcfab59..045a3da74dcd0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -94,7 +94,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) exte
/**
* Print usage and exit JVM with the given exit code.
*/
- private def printUsageAndExit(exitCode: Int) {
+ private def printUsageAndExit(exitCode: Int): Unit = {
// scalastyle:off println
System.err.println(
"Usage: Master [options]\n" +
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index b30bc821b7324..9a695e15a9cea 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -88,7 +88,7 @@ abstract class PersistenceEngine {
}
}
- def close() {}
+ def close(): Unit = {}
}
private[master] class BlackHolePersistenceEngine extends PersistenceEngine {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index a33b15354efea..48458819d641c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -18,9 +18,7 @@
package org.apache.spark.deploy.master
import scala.collection.mutable
-import scala.reflect.ClassTag
-import org.apache.spark.deploy.StandaloneResourceUtils.MutableResourceInfo
import org.apache.spark.resource.{ResourceAllocator, ResourceInformation, ResourceRequirement}
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
@@ -93,7 +91,7 @@ private[spark] class WorkerInfo(
init()
}
- private def init() {
+ private def init(): Unit = {
executors = new mutable.HashMap
drivers = new mutable.HashMap
state = WorkerState.ALIVE
@@ -107,13 +105,13 @@ private[spark] class WorkerInfo(
host + ":" + port
}
- def addExecutor(exec: ExecutorDesc) {
+ def addExecutor(exec: ExecutorDesc): Unit = {
executors(exec.fullId) = exec
coresUsed += exec.cores
memoryUsed += exec.memory
}
- def removeExecutor(exec: ExecutorDesc) {
+ def removeExecutor(exec: ExecutorDesc): Unit = {
if (executors.contains(exec.fullId)) {
executors -= exec.fullId
coresUsed -= exec.cores
@@ -126,13 +124,13 @@ private[spark] class WorkerInfo(
executors.values.exists(_.application == app)
}
- def addDriver(driver: DriverInfo) {
+ def addDriver(driver: DriverInfo): Unit = {
drivers(driver.id) = driver
memoryUsed += driver.desc.mem
coresUsed += driver.desc.cores
}
- def removeDriver(driver: DriverInfo) {
+ def removeDriver(driver: DriverInfo): Unit = {
drivers -= driver.id
memoryUsed -= driver.desc.mem
coresUsed -= driver.desc.cores
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 47f309144bdc0..d4ae977b19f4b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -36,7 +36,7 @@ private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderEle
start()
- private def start() {
+ private def start(): Unit = {
logInfo("Starting ZooKeeper LeaderElection agent")
zk = SparkCuratorUtil.newClient(conf)
leaderLatch = new LeaderLatch(zk, workingDir)
@@ -44,12 +44,12 @@ private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderEle
leaderLatch.start()
}
- override def stop() {
+ override def stop(): Unit = {
leaderLatch.close()
zk.close()
}
- override def isLeader() {
+ override def isLeader(): Unit = {
synchronized {
// could have lost leadership by now.
if (!leaderLatch.hasLeadership) {
@@ -61,7 +61,7 @@ private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderEle
}
}
- override def notLeader() {
+ override def notLeader(): Unit = {
synchronized {
// could have gained leadership by now.
if (leaderLatch.hasLeadership) {
@@ -73,7 +73,7 @@ private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderEle
}
}
- private def updateLeadershipStatus(isLeader: Boolean) {
+ private def updateLeadershipStatus(isLeader: Boolean): Unit = {
if (isLeader && status == LeadershipStatus.NOT_LEADER) {
status = LeadershipStatus.LEADER
masterInstance.electedLeader()
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 73dd0de017960..8eae445b439d9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -55,11 +55,11 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer
.filter(_.startsWith(prefix)).flatMap(deserializeFromFile[T])
}
- override def close() {
+ override def close(): Unit = {
zk.close()
}
- private def serializeIntoFile(path: String, value: AnyRef) {
+ private def serializeIntoFile(path: String, value: AnyRef): Unit = {
val serialized = serializer.newInstance().serialize(value)
val bytes = new Array[Byte](serialized.remaining())
serialized.get(bytes)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index e8b614527f69c..042ec54ee1240 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -27,7 +27,6 @@ import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, MasterStateRe
import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.StandaloneResourceUtils._
import org.apache.spark.deploy.master._
-import org.apache.spark.resource.ResourceInformation
import org.apache.spark.ui.{UIUtils, WebUIPage}
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index be402ae247511..86554ec4ec1c9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -40,7 +40,7 @@ class MasterWebUI(
initialize()
/** Initialize all components of the server. */
- def initialize() {
+ def initialize(): Unit = {
val masterPage = new MasterPage(this)
attachPage(new ApplicationPage(this))
attachPage(masterPage)
diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala
index 759d857d56e0e..f769ce468e49c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala
@@ -140,13 +140,21 @@ private[spark] class HadoopDelegationTokenManager(
* @param creds Credentials object where to store the delegation tokens.
*/
def obtainDelegationTokens(creds: Credentials): Unit = {
- val freshUGI = doLogin()
- freshUGI.doAs(new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- val (newTokens, _) = obtainDelegationTokens()
- creds.addAll(newTokens)
- }
- })
+ val currentUser = UserGroupInformation.getCurrentUser()
+ val hasKerberosCreds = principal != null ||
+ Option(currentUser.getRealUser()).getOrElse(currentUser).hasKerberosCredentials()
+
+ // Delegation tokens can only be obtained if the real user has Kerberos credentials, so
+ // skip creation when those are not available.
+ if (hasKerberosCreds) {
+ val freshUGI = doLogin()
+ freshUGI.doAs(new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ val (newTokens, _) = obtainDelegationTokens()
+ creds.addAll(newTokens)
+ }
+ })
+ }
}
/**
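A caller-side sketch of what the new guard buys (illustrative only; the token manager instance and surrounding code are assumptions, and only obtainDelegationTokens(creds) is taken from this patch). With the check in place, the call degrades to a no-op on non-Kerberized clusters, so callers no longer need their own guard:

    // Hypothetical sketch; must live under org.apache.spark to see the private[spark] class.
    import org.apache.hadoop.security.Credentials
    import org.apache.spark.deploy.security.HadoopDelegationTokenManager

    def fetchInitialTokens(tokenManager: HadoopDelegationTokenManager): Credentials = {
      val creds = new Credentials()
      // Populates creds only when a principal or Kerberos credentials are available;
      // otherwise the call is now a no-op.
      tokenManager.obtainDelegationTokens(creds)
      creds
    }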
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index 12e0dae3f5e5a..f7423f1fc3f1c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -102,12 +102,12 @@ object CommandUtils extends Logging {
}
/** Spawn a thread that will redirect a given stream to a file */
- def redirectStream(in: InputStream, file: File) {
+ def redirectStream(in: InputStream, file: File): Unit = {
val out = new FileOutputStream(file, true)
// TODO: It would be nice to add a shutdown hook here that explains why the output is
// terminating. Otherwise if the worker dies the executor logs will silently stop.
new Thread("redirect output to " + file) {
- override def run() {
+ override def run(): Unit = {
try {
Utils.copyStream(in, out, true)
} catch {
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 4934722c0d83e..53ec7b3a88f35 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -84,7 +84,7 @@ private[deploy] class DriverRunner(
/** Starts a thread to run and manage the driver. */
private[worker] def start() = {
new Thread("DriverRunner for " + driverId) {
- override def run() {
+ override def run(): Unit = {
var shutdownHook: AnyRef = null
try {
shutdownHook = ShutdownHookManager.addShutdownHook { () =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index 56356f5f27e27..45ffdde58d6c3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -32,7 +32,7 @@ import org.apache.spark.util._
* This is used in standalone cluster mode only.
*/
object DriverWrapper extends Logging {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
args.toList match {
/*
* IMPORTANT: Spark 1.3 provides a stable application submission gateway that is both
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 97939107f3057..2a5528bbe89cb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -31,7 +31,7 @@ import org.apache.spark.deploy.StandaloneResourceUtils.prepareResourcesFile
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.SPARK_EXECUTOR_PREFIX
import org.apache.spark.internal.config.UI._
-import org.apache.spark.resource.{ResourceInformation, ResourceUtils}
+import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.{ShutdownHookManager, Utils}
import org.apache.spark.util.logging.FileAppender
@@ -74,9 +74,9 @@ private[deploy] class ExecutorRunner(
// make sense to remove this in the future.
private var shutdownHook: AnyRef = null
- private[worker] def start() {
+ private[worker] def start(): Unit = {
workerThread = new Thread("ExecutorRunner for " + fullId) {
- override def run() { fetchAndRunExecutor() }
+ override def run(): Unit = { fetchAndRunExecutor() }
}
workerThread.start()
// Shutdown hook that kills actors on shutdown.
@@ -94,7 +94,7 @@ private[deploy] class ExecutorRunner(
*
* @param message the exception message which caused the executor's death
*/
- private def killProcess(message: Option[String]) {
+ private def killProcess(message: Option[String]): Unit = {
var exitCode: Option[Int] = None
if (process != null) {
logInfo("Killing process!")
@@ -118,7 +118,7 @@ private[deploy] class ExecutorRunner(
}
/** Stop this executor runner, including killing the process it launched */
- private[worker] def kill() {
+ private[worker] def kill(): Unit = {
if (workerThread != null) {
// the workerThread will kill the child process when interrupted
workerThread.interrupt()
@@ -145,7 +145,7 @@ private[deploy] class ExecutorRunner(
/**
* Download and run the executor described in our ApplicationDescription
*/
- private def fetchAndRunExecutor() {
+ private def fetchAndRunExecutor(): Unit = {
try {
val resourceFileOpt = prepareResourcesFile(SPARK_EXECUTOR_PREFIX, resources, executorDir)
// Launch the process
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 3731b6aec6522..4be495ac4f13f 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -190,14 +190,14 @@ private[deploy] class Worker(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
- private def createWorkDir() {
+ private def createWorkDir(): Unit = {
workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
if (!Utils.createDirectory(workDir)) {
System.exit(1)
}
}
- override def onStart() {
+ override def onStart(): Unit = {
assert(!registered)
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
host, port, cores, Utils.megabytesToString(memory)))
@@ -268,7 +268,8 @@ private[deploy] class Worker(
* @param masterAddress the new master address which the worker should use to connect in case of
* failure
*/
- private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String, masterAddress: RpcAddress) {
+ private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String,
+ masterAddress: RpcAddress): Unit = {
// activeMasterUrl is a valid Spark URL since we receive it from the master.
activeMasterUrl = masterRef.address.toSparkURL
activeMasterWebUiUrl = uiUrl
@@ -391,7 +392,7 @@ private[deploy] class Worker(
registrationRetryTimer = None
}
- private def registerWithMaster() {
+ private def registerWithMaster(): Unit = {
// onDisconnected may be triggered multiple times, so don't attempt registration
// if there are outstanding registration attempts scheduled.
registrationRetryTimer match {
@@ -410,7 +411,7 @@ private[deploy] class Worker(
}
}
- private def startExternalShuffleService() {
+ private def startExternalShuffleService(): Unit = {
try {
shuffleService.startIfEnabled()
} catch {
@@ -690,7 +691,7 @@ private[deploy] class Worker(
}
}
- private def masterDisconnected() {
+ private def masterDisconnected(): Unit = {
logError("Connection to master failed! Waiting for master to reconnect...")
connected = false
registerWithMaster()
@@ -736,7 +737,7 @@ private[deploy] class Worker(
"worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port)
}
- override def onStop() {
+ override def onStop(): Unit = {
releaseResources(conf, SPARK_WORKER_PREFIX, resources, pid)
cleanupThreadExecutor.shutdownNow()
metricsSystem.report()
@@ -834,7 +835,7 @@ private[deploy] object Worker extends Logging {
val ENDPOINT_NAME = "Worker"
private val SSL_NODE_LOCAL_CONFIG_PATTERN = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r
- def main(argStrings: Array[String]) {
+ def main(argStrings: Array[String]): Unit = {
Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler(
exitOnUncaughtException = false))
Utils.initDaemon(log)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 8c87708e960e6..42f684c0a1973 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -122,7 +122,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
/**
* Print usage and exit JVM with the given exit code.
*/
- def printUsageAndExit(exitCode: Int) {
+ def printUsageAndExit(exitCode: Int): Unit = {
// scalastyle:off println
System.err.println(
"Usage: Worker [options]
\n" +
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 96980c3ff0331..0f5e96c558490 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -43,7 +43,7 @@ class WorkerWebUI(
initialize()
/** Initialize all components of the server. */
- def initialize() {
+ def initialize(): Unit = {
val logPage = new LogPage(this)
attachPage(logPage)
attachPage(new WorkerPage(this))
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index e96c41a61b066..fbf2dc73ea075 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -70,7 +70,7 @@ private[spark] class CoarseGrainedExecutorBackend(
*/
private[executor] val taskResources = new mutable.HashMap[Long, Map[String, ResourceInformation]]
- override def onStart() {
+ override def onStart(): Unit = {
logInfo("Connecting to driver: " + driverUrl)
val resources = parseOrFindResources(resourcesFileOpt)
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
@@ -186,7 +186,7 @@ private[spark] class CoarseGrainedExecutorBackend(
}
}
- override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
+ override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit = {
val resources = taskResources.getOrElse(taskId, Map.empty[String, ResourceInformation])
val msg = StatusUpdate(executorId, taskId, state, data, resources)
if (TaskState.isFinished(state)) {
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c337d24381286..ce6d0322bafd5 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -139,20 +139,26 @@ private[spark] class Executor(
private val executorPlugins: Seq[ExecutorPlugin] = {
val pluginNames = conf.get(EXECUTOR_PLUGINS)
if (pluginNames.nonEmpty) {
- logDebug(s"Initializing the following plugins: ${pluginNames.mkString(", ")}")
+ logInfo(s"Initializing the following plugins: ${pluginNames.mkString(", ")}")
// Plugins need to load using a class loader that includes the executor's user classpath
val pluginList: Seq[ExecutorPlugin] =
Utils.withContextClassLoader(replClassLoader) {
val plugins = Utils.loadExtensions(classOf[ExecutorPlugin], pluginNames, conf)
plugins.foreach { plugin =>
- plugin.init()
- logDebug(s"Successfully loaded plugin " + plugin.getClass().getCanonicalName())
+ val pluginSource = new ExecutorPluginSource(plugin.getClass().getSimpleName())
+ val pluginContext = new ExecutorPluginContext(pluginSource.metricRegistry, conf,
+ executorId, executorHostname, isLocal)
+ plugin.init(pluginContext)
+ logInfo("Successfully loaded plugin " + plugin.getClass().getCanonicalName())
+ if (pluginSource.metricRegistry.getNames.size() > 0) {
+ env.metricsSystem.registerSource(pluginSource)
+ }
}
plugins
}
- logDebug("Finished initializing plugins")
+ logInfo("Finished initializing plugins")
pluginList
} else {
Nil
@@ -623,6 +629,11 @@ private[spark] class Executor(
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
+ case t: Throwable if env.isStopped =>
+ // Log the expected exception after executor.stop without stack traces
+ // see: SPARK-19147
+ logError(s"Exception in $taskName (TID $taskId): ${t.getMessage}")
+
case t: Throwable =>
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
@@ -846,7 +857,7 @@ private[spark] class Executor(
* Download any missing dependencies if we receive a new set of files and JARs from the
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
- private def updateDependencies(newFiles: Map[String, Long], newJars: Map[String, Long]) {
+ private def updateDependencies(newFiles: Map[String, Long], newJars: Map[String, Long]): Unit = {
lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
synchronized {
// Fetch missing dependencies
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/HiveTestUtils.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorPluginSource.scala
similarity index 64%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/test/HiveTestUtils.scala
rename to core/src/main/scala/org/apache/spark/executor/ExecutorPluginSource.scala
index 7631efedf46af..5625e953c5e67 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/HiveTestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorPluginSource.scala
@@ -15,18 +15,16 @@
* limitations under the License.
*/
-package org.apache.spark.sql.hive.test
+package org.apache.spark.executor
-import java.io.File
+import com.codahale.metrics.MetricRegistry
-import org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax
-import org.apache.hive.hcatalog.data.JsonSerDe
+import org.apache.spark.metrics.source.Source
-object HiveTestUtils {
+private[spark]
+class ExecutorPluginSource(name: String) extends Source {
- val getHiveContribJar: File =
- new File(classOf[UDAFExampleMax].getProtectionDomain.getCodeSource.getLocation.getPath)
+ override val metricRegistry = new MetricRegistry()
- val getHiveHcatalogCoreJar: File =
- new File(classOf[JsonSerDe].getProtectionDomain.getCodeSource.getLocation.getPath)
+ override val sourceName = name
}
diff --git a/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala
index 2111273d8b35a..0d5dcfb43cbfd 100644
--- a/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala
@@ -18,7 +18,6 @@
package org.apache.spark.executor
import java.io._
-import java.nio.charset.Charset
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files, Paths}
import java.util.Locale
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index ea79c7310349d..1470a23884bb0 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -137,6 +137,7 @@ class TaskMetrics private[spark] () extends Serializable {
private[spark] def setJvmGCTime(v: Long): Unit = _jvmGCTime.setValue(v)
private[spark] def setResultSerializationTime(v: Long): Unit =
_resultSerializationTime.setValue(v)
+ private[spark] def setPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.setValue(v)
private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v)
private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v)
private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
index 549395314ba61..f6902d1bf83a1 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -46,7 +46,7 @@ private[spark] class FixedLengthBinaryRecordReader
private var recordKey: LongWritable = null
private var recordValue: BytesWritable = null
- override def close() {
+ override def close(): Unit = {
if (fileInputStream != null) {
fileInputStream.close()
}
@@ -69,7 +69,7 @@ private[spark] class FixedLengthBinaryRecordReader
}
}
- override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) {
+ override def initialize(inputSplit: InputSplit, context: TaskAttemptContext): Unit = {
// the file input
val fileSplit = inputSplit.asInstanceOf[FileSplit]
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
index 6a4af01475646..57210da6a48eb 100644
--- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -44,7 +44,7 @@ private[spark] abstract class StreamFileInputFormat[T]
* Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API
* which is set through setMaxSplitSize
*/
- def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) {
+ def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int): Unit = {
val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES)
val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES)
val defaultParallelism = Math.max(sc.defaultParallelism, minPartitions)
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index 04c5c4b90e8a1..692deb7a3282f 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -48,7 +48,7 @@ private[spark] class WholeTextFileInputFormat
* Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API,
* which is set through setMaxSplitSize
*/
- def setMinPartitions(context: JobContext, minPartitions: Int) {
+ def setMinPartitions(context: JobContext, minPartitions: Int): Unit = {
val files = listStatus(context).asScala
val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum
val maxSplitSize = Math.ceil(totalLen * 1.0 /
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
index 28fd1ff1b77ca..0bd2d551cc912 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
@@ -31,7 +31,7 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileRecordReader, CombineFi
*/
private[spark] trait Configurable extends HConfigurable {
private var conf: Configuration = _
- def setConf(c: Configuration) {
+ def setConf(c: Configuration): Unit = {
conf = c
}
def getConf: Configuration = conf
diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala
index 0987917bac0e7..edfe9446094c8 100644
--- a/core/src/main/scala/org/apache/spark/internal/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala
@@ -53,44 +53,44 @@ trait Logging {
}
// Log methods that take only a String
- protected def logInfo(msg: => String) {
+ protected def logInfo(msg: => String): Unit = {
if (log.isInfoEnabled) log.info(msg)
}
- protected def logDebug(msg: => String) {
+ protected def logDebug(msg: => String): Unit = {
if (log.isDebugEnabled) log.debug(msg)
}
- protected def logTrace(msg: => String) {
+ protected def logTrace(msg: => String): Unit = {
if (log.isTraceEnabled) log.trace(msg)
}
- protected def logWarning(msg: => String) {
+ protected def logWarning(msg: => String): Unit = {
if (log.isWarnEnabled) log.warn(msg)
}
- protected def logError(msg: => String) {
+ protected def logError(msg: => String): Unit = {
if (log.isErrorEnabled) log.error(msg)
}
// Log methods that take Throwables (Exceptions/Errors) too
- protected def logInfo(msg: => String, throwable: Throwable) {
+ protected def logInfo(msg: => String, throwable: Throwable): Unit = {
if (log.isInfoEnabled) log.info(msg, throwable)
}
- protected def logDebug(msg: => String, throwable: Throwable) {
+ protected def logDebug(msg: => String, throwable: Throwable): Unit = {
if (log.isDebugEnabled) log.debug(msg, throwable)
}
- protected def logTrace(msg: => String, throwable: Throwable) {
+ protected def logTrace(msg: => String, throwable: Throwable): Unit = {
if (log.isTraceEnabled) log.trace(msg, throwable)
}
- protected def logWarning(msg: => String, throwable: Throwable) {
+ protected def logWarning(msg: => String, throwable: Throwable): Unit = {
if (log.isWarnEnabled) log.warn(msg, throwable)
}
- protected def logError(msg: => String, throwable: Throwable) {
+ protected def logError(msg: => String, throwable: Throwable): Unit = {
if (log.isErrorEnabled) log.error(msg, throwable)
}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/UI.scala b/core/src/main/scala/org/apache/spark/internal/config/UI.scala
index a11970ec73d88..1a8268161160b 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/UI.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/UI.scala
@@ -81,6 +81,13 @@ private[spark] object UI {
.booleanConf
.createWithDefault(true)
+ val UI_PROMETHEUS_ENABLED = ConfigBuilder("spark.ui.prometheus.enabled")
+ .internal()
+ .doc("Expose executor metrics at /metrics/executors/prometheus. " +
+ "For master/worker/driver metrics, you need to configure `conf/metrics.properties`.")
+ .booleanConf
+ .createWithDefault(false)
+
val UI_X_XSS_PROTECTION = ConfigBuilder("spark.ui.xXssProtection")
.doc("Value for HTTP X-XSS-Protection response header")
.stringConf
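A minimal driver-side sketch of turning the new flag on (the app name is illustrative; the endpoint path below is the one named in the doc string above):

    import org.apache.spark.SparkConf

    // Hypothetical setup: with the flag enabled, executor metrics are served by the
    // driver UI at /metrics/executors/prometheus (e.g. http://<driver>:4040/...).
    val conf = new SparkConf()
      .setAppName("prometheus-metrics-demo")
      .set("spark.ui.prometheus.enabled", "true")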
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index b898413ac8d76..d142d22929728 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -106,6 +106,11 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[spark] val DRIVER_LOG_ALLOW_EC =
+ ConfigBuilder("spark.driver.log.allowErasureCoding")
+ .booleanConf
+ .createWithDefault(false)
+
private[spark] val EVENT_LOG_ENABLED = ConfigBuilder("spark.eventLog.enabled")
.booleanConf
.createWithDefault(false)
@@ -243,7 +248,8 @@ package object config {
.createWithDefault(false)
private[spark] val MEMORY_OFFHEAP_SIZE = ConfigBuilder("spark.memory.offHeap.size")
- .doc("The absolute amount of memory in bytes which can be used for off-heap allocation. " +
+ .doc("The absolute amount of memory which can be used for off-heap allocation, " +
+ " in bytes unless otherwise specified. " +
"This setting has no impact on heap memory usage, so if your executors' total memory " +
"consumption must fit within some hard limit then be sure to shrink your JVM heap size " +
"accordingly. This must be set to a positive value when spark.memory.offHeap.enabled=true.")
@@ -1026,7 +1032,7 @@ package object config {
.booleanConf
.createWithDefault(false)
- private[spark] val SHUFFLE_UNDAFE_FAST_MERGE_ENABLE =
+ private[spark] val SHUFFLE_UNSAFE_FAST_MERGE_ENABLE =
ConfigBuilder("spark.shuffle.unsafe.fastMergeEnabled")
.doc("Whether to perform a fast spill merge.")
.booleanConf
@@ -1047,6 +1053,14 @@ package object config {
.checkValue(v => v > 0, "The value should be a positive integer.")
.createWithDefault(2000)
+ private[spark] val SHUFFLE_USE_OLD_FETCH_PROTOCOL =
+ ConfigBuilder("spark.shuffle.useOldFetchProtocol")
+ .doc("Whether to use the old protocol while doing the shuffle block fetching. " +
+ "It is only enabled while we need the compatibility in the scenario of new Spark " +
+ "version job fetching shuffle blocks from old version external shuffle service.")
+ .booleanConf
+ .createWithDefault(false)
+
private[spark] val MEMORY_MAP_LIMIT_FOR_TESTS =
ConfigBuilder("spark.storage.memoryMapLimitForTests")
.internal()
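A small sketch of the intended use of spark.shuffle.useOldFetchProtocol added above (the scenario wording is an assumption based on the doc string): during a rolling upgrade where new-version executors must still fetch from an old-version external shuffle service, the flag is flipped on explicitly:

    import org.apache.spark.SparkConf

    // Hypothetical sketch: keep the old shuffle fetch protocol for compatibility.
    val conf = new SparkConf()
      .set("spark.shuffle.useOldFetchProtocol", "true")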
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index adbd59c9f03b4..5205a2d568ac3 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -44,6 +44,10 @@ trait CompressionCodec {
def compressedOutputStream(s: OutputStream): OutputStream
+ private[spark] def compressedContinuousOutputStream(s: OutputStream): OutputStream = {
+ compressedOutputStream(s)
+ }
+
def compressedInputStream(s: InputStream): InputStream
private[spark] def compressedContinuousInputStream(s: InputStream): InputStream = {
@@ -220,6 +224,12 @@ class ZStdCompressionCodec(conf: SparkConf) extends CompressionCodec {
new BufferedOutputStream(new ZstdOutputStream(s, level), bufferSize)
}
+ override private[spark] def compressedContinuousOutputStream(s: OutputStream) = {
+ // SPARK-29322: Set "closeFrameOnFlush" to 'true' to let continuous input stream not being
+ // stuck on reading open frame.
+ new BufferedOutputStream(new ZstdOutputStream(s, level).setCloseFrameOnFlush(true), bufferSize)
+ }
+
override def compressedInputStream(s: InputStream): InputStream = {
// Wrap the zstd input stream in a buffered input stream so that we can
// avoid the excessive overhead of JNI calls while uncompressing small amounts of data.
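A rough internal-usage sketch of the new continuous output path (the file name and payload are made up; createCodec and the method above are private[spark], so this only applies inside Spark itself):

    import java.io.FileOutputStream
    import org.apache.spark.SparkConf
    import org.apache.spark.io.CompressionCodec

    // Hypothetical sketch: with closeFrameOnFlush, each flush() leaves only complete
    // zstd frames behind, so a concurrent compressedContinuousInputStream reader does
    // not block on an open frame.
    val codec = CompressionCodec.createCodec(new SparkConf(), "zstd")
    val out = codec.compressedContinuousOutputStream(
      new FileOutputStream("/tmp/events.inprogress"))
    out.write("partial event data".getBytes("UTF-8"))
    out.flush()   // frame closed here; a tailing reader can decode everything so far
    out.close()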
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
index b6be8aaefd351..d98d5e3b81aa0 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
@@ -38,7 +38,7 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging {
private[metrics] val properties = new Properties()
private[metrics] var perInstanceSubProperties: mutable.HashMap[String, Properties] = null
- private def setDefaultProperties(prop: Properties) {
+ private def setDefaultProperties(prop: Properties): Unit = {
prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet")
prop.setProperty("*.sink.servlet.path", "/metrics/json")
prop.setProperty("master.sink.servlet.path", "/metrics/master/json")
@@ -49,7 +49,7 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging {
* Load properties from various places, based on precedence
* If the same property is set again latter on in the method, it overwrites the previous value
*/
- def initialize() {
+ def initialize(): Unit = {
// Add default properties in case there's no properties file
setDefaultProperties(properties)
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index c96640a6fab3f..ead8fde3e0872 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -28,7 +28,7 @@ import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
-import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
+import org.apache.spark.metrics.sink.{MetricsServlet, PrometheusServlet, Sink}
import org.apache.spark.metrics.source.{Source, StaticSources}
import org.apache.spark.util.Utils
@@ -83,18 +83,20 @@ private[spark] class MetricsSystem private (
// Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui
private var metricsServlet: Option[MetricsServlet] = None
+ private var prometheusServlet: Option[PrometheusServlet] = None
/**
* Get any UI handlers used by this metrics system; can only be called after start().
*/
def getServletHandlers: Array[ServletContextHandler] = {
require(running, "Can only call getServletHandlers on a running MetricsSystem")
- metricsServlet.map(_.getHandlers(conf)).getOrElse(Array())
+ metricsServlet.map(_.getHandlers(conf)).getOrElse(Array()) ++
+ prometheusServlet.map(_.getHandlers(conf)).getOrElse(Array())
}
metricsConfig.initialize()
- def start(registerStaticSources: Boolean = true) {
+ def start(registerStaticSources: Boolean = true): Unit = {
require(!running, "Attempting to start a MetricsSystem that is already running")
running = true
if (registerStaticSources) {
@@ -105,7 +107,7 @@ private[spark] class MetricsSystem private (
sinks.foreach(_.start)
}
- def stop() {
+ def stop(): Unit = {
if (running) {
sinks.foreach(_.stop)
} else {
@@ -114,7 +116,7 @@ private[spark] class MetricsSystem private (
running = false
}
- def report() {
+ def report(): Unit = {
sinks.foreach(_.report())
}
@@ -124,7 +126,7 @@ private[spark] class MetricsSystem private (
* If either ID is not available, this defaults to just using .
*
* @param source Metric source to be named by this method.
- * @return An unique metric name for each combination of
+ * @return A unique metric name for each combination of
* application, executor/driver and metric source.
*/
private[spark] def buildRegistryName(source: Source): String = {
@@ -155,7 +157,7 @@ private[spark] class MetricsSystem private (
def getSourcesByName(sourceName: String): Seq[Source] =
sources.filter(_.sourceName == sourceName)
- def registerSource(source: Source) {
+ def registerSource(source: Source): Unit = {
sources += source
try {
val regName = buildRegistryName(source)
@@ -165,13 +167,13 @@ private[spark] class MetricsSystem private (
}
}
- def removeSource(source: Source) {
+ def removeSource(source: Source): Unit = {
sources -= source
val regName = buildRegistryName(source)
registry.removeMatching((name: String, _: Metric) => name.startsWith(regName))
}
- private def registerSources() {
+ private def registerSources(): Unit = {
val instConfig = metricsConfig.getInstance(instance)
val sourceConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SOURCE_REGEX)
@@ -187,7 +189,7 @@ private[spark] class MetricsSystem private (
}
}
- private def registerSinks() {
+ private def registerSinks(): Unit = {
val instConfig = metricsConfig.getInstance(instance)
val sinkConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SINK_REGEX)
@@ -201,6 +203,12 @@ private[spark] class MetricsSystem private (
classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
.newInstance(kv._2, registry, securityMgr)
metricsServlet = Some(servlet)
+ } else if (kv._1 == "prometheusServlet") {
+ val servlet = Utils.classForName[PrometheusServlet](classPath)
+ .getConstructor(
+ classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
+ .newInstance(kv._2, registry, securityMgr)
+ prometheusServlet = Some(servlet)
} else {
val sink = Utils.classForName[Sink](classPath)
.getConstructor(
@@ -225,7 +233,7 @@ private[spark] object MetricsSystem {
private[this] val MINIMAL_POLL_UNIT = TimeUnit.SECONDS
private[this] val MINIMAL_POLL_PERIOD = 1
- def checkMinimalPollingPeriod(pollUnit: TimeUnit, pollPeriod: Int) {
+ def checkMinimalPollingPeriod(pollUnit: TimeUnit, pollPeriod: Int): Unit = {
val period = MINIMAL_POLL_UNIT.convert(pollPeriod, pollUnit)
if (period < MINIMAL_POLL_PERIOD) {
throw new IllegalArgumentException("Polling period " + pollPeriod + " " + pollUnit +
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
index fce556fd0382c..bfd23168e4003 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
@@ -50,15 +50,15 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR
.convertRatesTo(TimeUnit.SECONDS)
.build()
- override def start() {
+ override def start(): Unit = {
reporter.start(pollPeriod, pollUnit)
}
- override def stop() {
+ override def stop(): Unit = {
reporter.stop()
}
- override def report() {
+ override def report(): Unit = {
reporter.report()
}
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
index 88bba2fdbd1c6..579b8e0c0e984 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
@@ -59,15 +59,15 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis
.convertRatesTo(TimeUnit.SECONDS)
.build(new File(pollDir))
- override def start() {
+ override def start(): Unit = {
reporter.start(pollPeriod, pollUnit)
}
- override def stop() {
+ override def stop(): Unit = {
reporter.stop()
}
- override def report() {
+ override def report(): Unit = {
reporter.report()
}
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index 05d553ed30ff0..6ce64cd3543fe 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -89,15 +89,15 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric
.filter(filter)
.build(graphite)
- override def start() {
+ override def start(): Unit = {
reporter.start(pollPeriod, pollUnit)
}
- override def stop() {
+ override def stop(): Unit = {
reporter.stop()
}
- override def report() {
+ override def report(): Unit = {
reporter.report()
}
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
index 1992b42ac7f6b..9e94a868ccc36 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
@@ -28,14 +28,14 @@ private[spark] class JmxSink(val property: Properties, val registry: MetricRegis
val reporter: JmxReporter = JmxReporter.forRegistry(registry).build()
- override def start() {
+ override def start(): Unit = {
reporter.start()
}
- override def stop() {
+ override def stop(): Unit = {
reporter.stop()
}
- override def report() { }
+ override def report(): Unit = { }
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
index bea24ca7807e4..7dd27d4fb9bf3 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -59,9 +59,9 @@ private[spark] class MetricsServlet(
mapper.writeValueAsString(registry)
}
- override def start() { }
+ override def start(): Unit = { }
- override def stop() { }
+ override def stop(): Unit = { }
- override def report() { }
+ override def report(): Unit = { }
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala
new file mode 100644
index 0000000000000..7c33bce78378d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/PrometheusServlet.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.util.Properties
+import javax.servlet.http.HttpServletRequest
+
+import com.codahale.metrics.MetricRegistry
+import org.eclipse.jetty.servlet.ServletContextHandler
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.ui.JettyUtils._
+
+/**
+ * This exposes the metrics of the given registry with Prometheus format.
+ *
+ * The output is consistent with /metrics/json result in terms of item ordering
+ * and with the previous result of Spark JMX Sink + Prometheus JMX Converter combination
+ * in terms of key string format.
+ */
+private[spark] class PrometheusServlet(
+ val property: Properties,
+ val registry: MetricRegistry,
+ securityMgr: SecurityManager)
+ extends Sink {
+
+ val SERVLET_KEY_PATH = "path"
+
+ val servletPath = property.getProperty(SERVLET_KEY_PATH)
+
+ def getHandlers(conf: SparkConf): Array[ServletContextHandler] = {
+ Array[ServletContextHandler](
+ createServletHandler(servletPath,
+ new ServletParams(request => getMetricsSnapshot(request), "text/plain"), conf)
+ )
+ }
+
+ def getMetricsSnapshot(request: HttpServletRequest): String = {
+ import scala.collection.JavaConverters._
+
+ val sb = new StringBuilder()
+ registry.getGauges.asScala.foreach { case (k, v) =>
+ if (!v.getValue.isInstanceOf[String]) {
+ sb.append(s"${normalizeKey(k)}Value ${v.getValue}\n")
+ }
+ }
+ registry.getCounters.asScala.foreach { case (k, v) =>
+ sb.append(s"${normalizeKey(k)}Count ${v.getCount}\n")
+ }
+ registry.getHistograms.asScala.foreach { case (k, h) =>
+ val snapshot = h.getSnapshot
+ val prefix = normalizeKey(k)
+ sb.append(s"${prefix}Count ${h.getCount}\n")
+ sb.append(s"${prefix}Max ${snapshot.getMax}\n")
+ sb.append(s"${prefix}Mean ${snapshot.getMean}\n")
+ sb.append(s"${prefix}Min ${snapshot.getMin}\n")
+ sb.append(s"${prefix}50thPercentile ${snapshot.getMedian}\n")
+ sb.append(s"${prefix}75thPercentile ${snapshot.get75thPercentile}\n")
+ sb.append(s"${prefix}95thPercentile ${snapshot.get95thPercentile}\n")
+ sb.append(s"${prefix}98thPercentile ${snapshot.get98thPercentile}\n")
+ sb.append(s"${prefix}99thPercentile ${snapshot.get99thPercentile}\n")
+ sb.append(s"${prefix}999thPercentile ${snapshot.get999thPercentile}\n")
+ sb.append(s"${prefix}StdDev ${snapshot.getStdDev}\n")
+ }
+ registry.getMeters.entrySet.iterator.asScala.foreach { kv =>
+ val prefix = normalizeKey(kv.getKey)
+ val meter = kv.getValue
+ sb.append(s"${prefix}Count ${meter.getCount}\n")
+ sb.append(s"${prefix}MeanRate ${meter.getMeanRate}\n")
+ sb.append(s"${prefix}OneMinuteRate ${meter.getOneMinuteRate}\n")
+ sb.append(s"${prefix}FiveMinuteRate ${meter.getFiveMinuteRate}\n")
+ sb.append(s"${prefix}FifteenMinuteRate ${meter.getFifteenMinuteRate}\n")
+ }
+ registry.getTimers.entrySet.iterator.asScala.foreach { kv =>
+ val prefix = normalizeKey(kv.getKey)
+ val timer = kv.getValue
+ val snapshot = timer.getSnapshot
+ sb.append(s"${prefix}Count ${timer.getCount}\n")
+ sb.append(s"${prefix}Max ${snapshot.getMax}\n")
+ sb.append(s"${prefix}Mean ${snapshot.getMax}\n")
+ sb.append(s"${prefix}Min ${snapshot.getMin}\n")
+ sb.append(s"${prefix}50thPercentile ${snapshot.getMedian}\n")
+ sb.append(s"${prefix}75thPercentile ${snapshot.get75thPercentile}\n")
+ sb.append(s"${prefix}95thPercentile ${snapshot.get95thPercentile}\n")
+ sb.append(s"${prefix}98thPercentile ${snapshot.get98thPercentile}\n")
+ sb.append(s"${prefix}99thPercentile ${snapshot.get99thPercentile}\n")
+ sb.append(s"${prefix}999thPercentile ${snapshot.get999thPercentile}\n")
+ sb.append(s"${prefix}StdDev ${snapshot.getStdDev}\n")
+ sb.append(s"${prefix}FifteenMinuteRate ${timer.getFifteenMinuteRate}\n")
+ sb.append(s"${prefix}FiveMinuteRate ${timer.getFiveMinuteRate}\n")
+ sb.append(s"${prefix}OneMinuteRate ${timer.getOneMinuteRate}\n")
+ sb.append(s"${prefix}MeanRate ${timer.getMeanRate}\n")
+ }
+ sb.toString()
+ }
+
+ private def normalizeKey(key: String): String = {
+ s"metrics_${key.replaceAll("[^a-zA-Z0-9]", "_")}_"
+ }
+
+ override def start(): Unit = { }
+
+ override def stop(): Unit = { }
+
+ override def report(): Unit = { }
+}
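A hedged configuration sketch for wiring the new sink up. The sink instance name "prometheusServlet" and its "path" property come from the MetricsSystem change and the servlet above; the chosen path value and the use of spark.metrics.conf.* overrides (instead of editing conf/metrics.properties directly) are illustrative:

    import org.apache.spark.SparkConf

    // Hypothetical sketch: the keys mirror the corresponding metrics.properties entries.
    val conf = new SparkConf()
      .set("spark.metrics.conf.*.sink.prometheusServlet.class",
        "org.apache.spark.metrics.sink.PrometheusServlet")
      .set("spark.metrics.conf.*.sink.prometheusServlet.path",
        "/metrics/prometheus")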
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala
index 7fa4ba7622980..968d5ca809e72 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala
@@ -53,15 +53,15 @@ private[spark] class Slf4jSink(
.convertRatesTo(TimeUnit.SECONDS)
.build()
- override def start() {
+ override def start(): Unit = {
reporter.start(pollPeriod, pollUnit)
}
- override def stop() {
+ override def stop(): Unit = {
reporter.stop()
}
- override def report() {
+ override def report(): Unit = {
reporter.report()
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 1d27fe7db193f..ffb696029a033 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -116,7 +116,8 @@ private[spark] class NettyBlockTransferService(
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
- override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+ override def createAndStart(blockIds: Array[String],
+ listener: BlockFetchingListener): Unit = {
try {
val client = clientFactory.createClient(host, port)
new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
diff --git a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala
index b089bbd7e972e..34c04f4025a96 100644
--- a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala
+++ b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala
@@ -43,7 +43,7 @@ private[spark] class ApproximateActionListener[T, U, R](
var failure: Option[Exception] = None // Set if the job has failed (permanently)
var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
- override def taskSucceeded(index: Int, result: Any) {
+ override def taskSucceeded(index: Int, result: Any): Unit = {
synchronized {
evaluator.merge(index, result.asInstanceOf[U])
finishedTasks += 1
@@ -56,7 +56,7 @@ private[spark] class ApproximateActionListener[T, U, R](
}
}
- override def jobFailed(exception: Exception) {
+ override def jobFailed(exception: Exception): Unit = {
synchronized {
failure = Some(exception)
this.notifyAll()
diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala
index 25cb7490aa9c9..012d4769617f6 100644
--- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala
+++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala
@@ -61,7 +61,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) {
* Set a handler to be called if this PartialResult's job fails. Only one failure handler
* is supported per PartialResult.
*/
- def onFail(handler: Exception => Unit) {
+ def onFail(handler: Exception => Unit): Unit = {
synchronized {
if (failureHandler.isDefined) {
throw new UnsupportedOperationException("onFail cannot be called twice")
@@ -85,7 +85,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) {
override def onComplete(handler: T => Unit): PartialResult[T] = synchronized {
PartialResult.this.onComplete(handler.compose(f)).map(f)
}
- override def onFail(handler: Exception => Unit) {
+ override def onFail(handler: Exception => Unit): Unit = {
synchronized {
PartialResult.this.onFail(handler)
}
@@ -100,7 +100,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) {
}
}
- private[spark] def setFinalValue(value: R) {
+ private[spark] def setFinalValue(value: R): Unit = {
synchronized {
if (finalValue.isDefined) {
throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult")
@@ -115,7 +115,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) {
private def getFinalValueInternal() = finalValue
- private[spark] def setFailure(exception: Exception) {
+ private[spark] def setFailure(exception: Exception): Unit = {
synchronized {
if (failure.isDefined) {
throw new UnsupportedOperationException("setFailure called twice on a PartialResult")
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 23cf19d55b4ae..a5c3e2a2dfe2a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -61,7 +61,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
* irreversible operation, as the data in the blocks cannot be recovered back
* once removed. Use it with caution.
*/
- private[spark] def removeBlocks() {
+ private[spark] def removeBlocks(): Unit = {
blockIds.foreach { blockId =>
sparkContext.env.blockManager.master.removeBlock(blockId)
}
@@ -77,7 +77,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
}
/** Check if this BlockRDD is valid. If not valid, exception is thrown. */
- private[spark] def assertValid() {
+ private[spark] def assertValid(): Unit = {
if (!isValid) {
throw new SparkException(
"Attempted to use %s after its blocks have been removed!".format(toString))
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 57108dcedcf0c..fddd35b657479 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -85,7 +85,7 @@ class CartesianRDD[T: ClassTag, U: ClassTag](
}
)
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
rdd1 = null
rdd2 = null
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 909f58512153b..500d306f336ac 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -187,7 +187,7 @@ class CoGroupedRDD[K: ClassTag](
createCombiner, mergeValue, mergeCombiners)
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
rdds = null
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 55c141c2b8a0a..58a0c0c400e09 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -107,7 +107,7 @@ private[spark] class CoalescedRDD[T: ClassTag](
})
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
prev = null
}
@@ -239,7 +239,7 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
* locations (2 * n log(n))
* @param targetLen The number of desired partition groups
*/
- def setupGroups(targetLen: Int, partitionLocs: PartitionLocations) {
+ def setupGroups(targetLen: Int, partitionLocs: PartitionLocations): Unit = {
// deal with empty case, just create targetLen partition groups with no preferred location
if (partitionLocs.partsWithLocs.isEmpty) {
(1 to targetLen).foreach(_ => groupArr += new PartitionGroup())
@@ -328,7 +328,7 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
def throwBalls(
maxPartitions: Int,
prev: RDD[_],
- balanceSlack: Double, partitionLocs: PartitionLocations) {
+ balanceSlack: Double, partitionLocs: PartitionLocations): Unit = {
if (noLocality) { // no preferredLocations in parent RDD, no randomization needed
if (maxPartitions > groupArr.size) { // just return prev.partitions
for ((p, i) <- prev.partitions.zipWithIndex) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index f3f9be3562922..ff4928dae6bf8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -375,7 +375,7 @@ class HadoopRDD[K, V](
locs.getOrElse(hsplit.getLocations.filter(_ != "localhost"))
}
- override def checkpoint() {
+ override def checkpoint(): Unit = {
// Do nothing. Hadoop RDD should not be checkpointed.
}
@@ -412,7 +412,7 @@ private[spark] object HadoopRDD extends Logging {
/** Add Hadoop configuration specific to a single partition and attempt. */
def addLocalConfiguration(jobTrackerId: String, jobId: Int, splitId: Int, attemptId: Int,
- conf: JobConf) {
+ conf: JobConf): Unit = {
val jobID = new JobID(jobTrackerId, jobId)
val taId = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, splitId), attemptId)
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 56ef3e107a980..fccabcdd169c6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -109,7 +109,7 @@ class JdbcRDD[T: ClassTag](
}
}
- override def close() {
+ override def close(): Unit = {
try {
if (null != rs) {
rs.close()
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index aa61997122cf4..39520a9734b06 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -51,7 +51,7 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
override def compute(split: Partition, context: TaskContext): Iterator[U] =
f(context, split.index, firstParent[T].iterator(split, context))
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
prev = null
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index e23133682360f..1e39e10856877 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -261,7 +261,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
} else {
StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed)
}
- self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
+ self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true, isOrderSensitive = true)
}
/**
@@ -291,7 +291,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
} else {
StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed)
}
- self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
+ self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true, isOrderSensitive = true)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index d744d67592545..965618ee827d1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -101,7 +101,7 @@ class PartitionerAwareUnionRDD[T: ClassTag](
}
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
rdds = null
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
index 15691a8fc8eaa..c8cdaa60e4335 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -67,4 +67,12 @@ private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
thisSampler.setSeed(split.seed)
thisSampler.sample(firstParent[T].iterator(split.prev, context))
}
+
+ override protected def getOutputDeterministicLevel = {
+ if (prev.outputDeterministicLevel == DeterministicLevel.UNORDERED) {
+ DeterministicLevel.INDETERMINATE
+ } else {
+ super.getOutputDeterministicLevel
+ }
+ }
}
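
A rough sketch of the same determinism-downgrading pattern the hunk above adds, assuming the class sits in the org.apache.spark.rdd package (DeterministicLevel and outputDeterministicLevel are private[spark]); the class and field names are illustrative, not Spark code:

    package org.apache.spark.rdd

    import scala.reflect.ClassTag

    // An RDD whose per-partition output depends on the order of its input rows:
    // if the parent is only UNORDERED-deterministic, a retried upstream stage may
    // replay the same rows in a different order, so the output must be reported
    // as INDETERMINATE rather than inheriting the parent's level.
    abstract class OrderSensitiveRDD[T: ClassTag](prev: RDD[T]) extends RDD[T](prev) {
      override protected def getOutputDeterministicLevel: DeterministicLevel.Value = {
        if (prev.outputDeterministicLevel == DeterministicLevel.UNORDERED) {
          DeterministicLevel.INDETERMINATE
        } else {
          super.getOutputDeterministicLevel
        }
      }
    }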
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index eafe3b17c2136..08fc309d5238e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -430,8 +430,6 @@ abstract class RDD[T: ClassTag](
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
- *
- * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207.
*/
def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope {
coalesce(numPartitions, shuffle = true)
@@ -557,7 +555,7 @@ abstract class RDD[T: ClassTag](
val sampler = new BernoulliCellSampler[T](lb, ub)
sampler.setSeed(seed + index)
sampler.sample(partition)
- }, preservesPartitioning = true)
+ }, isOrderSensitive = true, preservesPartitioning = true)
}
/**
@@ -870,6 +868,29 @@ abstract class RDD[T: ClassTag](
preservesPartitioning)
}
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
+ * of the original partition.
+ *
+ * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+ * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
+ *
+ * `isOrderSensitive` indicates whether the function is order-sensitive. If it is order
+ * sensitive, it may return a totally different result when the input order
+ * is changed. Stateful functions are typically order-sensitive.
+ */
+ private[spark] def mapPartitionsWithIndex[U: ClassTag](
+ f: (Int, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean,
+ isOrderSensitive: Boolean): RDD[U] = withScope {
+ val cleanedF = sc.clean(f)
+ new MapPartitionsRDD(
+ this,
+ (_: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
+ preservesPartitioning,
+ isOrderSensitive = isOrderSensitive)
+ }
+
/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
* second element in each RDD, etc. Assumes that the two RDDs have the *same number of
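
A hedged usage sketch of what order sensitivity means in practice. The private[spark] overload added above is not callable from user code, so this uses the public mapPartitionsWithIndex purely to illustrate why a seeded per-partition sampler is order-sensitive:

    import scala.util.Random
    import org.apache.spark.sql.SparkSession

    object OrderSensitivityDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
        val rdd = spark.sparkContext.parallelize(1 to 100, 4)
        // Keep roughly 10% of each partition with a per-partition seed. Which rows
        // survive depends on the order they are pulled from the iterator, so a
        // retried parent that replays rows in a different order changes the result.
        val sampled = rdd.mapPartitionsWithIndex { (index, iter) =>
          val rng = new Random(42 + index)
          iter.filter(_ => rng.nextDouble() < 0.1)
        }
        println(sampled.count())
        spark.stop()
      }
    }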
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
index d165610291f1d..2caf9761b4432 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -166,7 +166,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
def writePartitionToCheckpointFile[T: ClassTag](
path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
- blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
+ blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]): Unit = {
val env = SparkEnv.get
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(broadcastedConf.value.value)
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 5ec99b7f4f3ab..0930a5c9cfb96 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -108,7 +108,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
.asInstanceOf[Iterator[(K, C)]]
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
prev = null
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 42d190377f104..d5a811d4dc3fd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -127,7 +127,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
map.asScala.iterator.map(t => t._2.iterator.map((t._1, _))).flatten
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
rdd1 = null
rdd2 = null
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 36589e93a1c5e..63fa3c2487c33 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -21,6 +21,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.ForkJoinTaskSupport
+import scala.collection.parallel.immutable.ParVector
import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
@@ -75,13 +76,13 @@ class UnionRDD[T: ClassTag](
override def getPartitions: Array[Partition] = {
val parRDDs = if (isPartitionListingParallel) {
- val parArray = rdds.par
+ val parArray = new ParVector(rdds.toVector)
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
parArray
} else {
rdds
}
- val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
+ val array = new Array[Partition](parRDDs.map(_.partitions.length).sum)
var pos = 0
for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
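
A standalone sketch of the ParVector + ForkJoinTaskSupport pattern the hunk above switches to in place of rdds.par. The pool size and the squaring function are placeholders, and it assumes Scala 2.12+, where ForkJoinTaskSupport accepts a java.util.concurrent.ForkJoinPool:

    import java.util.concurrent.ForkJoinPool

    import scala.collection.parallel.ForkJoinTaskSupport
    import scala.collection.parallel.immutable.ParVector

    object ParVectorSketch {
      def main(args: Array[String]): Unit = {
        // Evaluate a per-element function in parallel on an explicit pool, the way
        // UnionRDD lists the partitions of many parent RDDs concurrently.
        val inputs = new ParVector(Vector.tabulate(16)(identity))
        inputs.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(8))
        val total = inputs.map(i => i * i).sum // stand-in for _.partitions.length
        println(total)
      }
    }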
@@ -108,7 +109,7 @@ class UnionRDD[T: ClassTag](
override def getPreferredLocations(s: Partition): Seq[String] =
s.asInstanceOf[UnionPartition[T]].preferredLocations()
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
rdds = null
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index 3cb1231bd3477..678a48948a3c1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -70,7 +70,7 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
s.asInstanceOf[ZippedPartitionsPartition].preferredLocations
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
rdds = null
}
@@ -89,7 +89,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
rdd1 = null
rdd2 = null
@@ -114,7 +114,7 @@ private[spark] class ZippedPartitionsRDD3
rdd3.iterator(partitions(2), context))
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
rdd1 = null
rdd2 = null
@@ -142,7 +142,7 @@ private[spark] class ZippedPartitionsRDD4
rdd4.iterator(partitions(3), context))
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
rdd1 = null
rdd2 = null
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 9df59459ca799..81e0543ccefef 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -29,8 +29,6 @@ import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import scala.concurrent.duration._
import scala.util.control.NonFatal
-import org.apache.commons.lang3.SerializationUtils
-
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
@@ -229,7 +227,7 @@ private[spark] class DAGScheduler(
/**
* Called by the TaskSetManager to report task's starting.
*/
- def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ def taskStarted(task: Task[_], taskInfo: TaskInfo): Unit = {
eventProcessLoop.post(BeginEvent(task, taskInfo))
}
@@ -237,7 +235,7 @@ private[spark] class DAGScheduler(
* Called by the TaskSetManager to report that a task has completed
* and results are being fetched remotely.
*/
- def taskGettingResult(taskInfo: TaskInfo) {
+ def taskGettingResult(taskInfo: TaskInfo): Unit = {
eventProcessLoop.post(GettingResultEvent(taskInfo))
}
@@ -560,7 +558,7 @@ private[spark] class DAGScheduler(
// caused by recursively visiting
val waitingForVisit = new ListBuffer[RDD[_]]
waitingForVisit += stage.rdd
- def visit(rdd: RDD[_]) {
+ def visit(rdd: RDD[_]): Unit = {
if (!visited(rdd)) {
visited += rdd
val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
@@ -591,7 +589,7 @@ private[spark] class DAGScheduler(
*/
private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = {
@tailrec
- def updateJobIdStageIdMapsList(stages: List[Stage]) {
+ def updateJobIdStageIdMapsList(stages: List[Stage]): Unit = {
if (stages.nonEmpty) {
val s = stages.head
s.jobIds += jobId
@@ -622,7 +620,7 @@ private[spark] class DAGScheduler(
"Job %d not registered for stage %d even though that stage was registered for the job"
.format(job.jobId, stageId))
} else {
- def removeStage(stageId: Int) {
+ def removeStage(stageId: Int): Unit = {
// data structures based on Stage
for (stage <- stageIdToStage.get(stageId)) {
if (runningStages.contains(stage)) {
@@ -698,7 +696,7 @@ private[spark] class DAGScheduler(
if (partitions.isEmpty) {
val time = clock.getTimeMillis()
listenerBus.post(
- SparkListenerJobStart(jobId, time, Seq[StageInfo](), SerializationUtils.clone(properties)))
+ SparkListenerJobStart(jobId, time, Seq[StageInfo](), Utils.cloneProperties(properties)))
listenerBus.post(
SparkListenerJobEnd(jobId, time, JobSucceeded))
// Return immediately if the job is running 0 tasks
@@ -710,7 +708,7 @@ private[spark] class DAGScheduler(
val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions.toArray, callSite, waiter,
- SerializationUtils.clone(properties)))
+ Utils.cloneProperties(properties)))
waiter
}
@@ -782,7 +780,7 @@ private[spark] class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, rdd.partitions.indices.toArray, callSite, listener,
- SerializationUtils.clone(properties)))
+ Utils.cloneProperties(properties)))
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -819,7 +817,7 @@ private[spark] class DAGScheduler(
this, jobId, 1,
(_: Int, r: MapOutputStatistics) => callback(r))
eventProcessLoop.post(MapStageSubmitted(
- jobId, dependency, callSite, waiter, SerializationUtils.clone(properties)))
+ jobId, dependency, callSite, waiter, Utils.cloneProperties(properties)))
waiter
}
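
Utils.cloneProperties itself is not shown in these hunks; a minimal sketch of the general idea it replaces SerializationUtils.clone with (the helper and object names below are illustrative):

    import java.util.Properties

    import scala.collection.JavaConverters._

    object PropertiesCloneSketch {
      // Copy entry-by-entry instead of round-tripping the Properties object
      // through Java serialization, which is what SerializationUtils.clone did.
      def cloneProperties(props: Properties): Properties = {
        if (props == null) {
          null
        } else {
          val copy = new Properties()
          props.asScala.foreach { case (k, v) => copy.setProperty(k, v) }
          copy
        }
      }

      def main(args: Array[String]): Unit = {
        val p = new Properties()
        p.setProperty("spark.scheduler.pool", "prod")
        val c = cloneProperties(p)
        p.setProperty("spark.scheduler.pool", "dev")
        println(c.getProperty("spark.scheduler.pool")) // prints "prod": independent copy
      }
    }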
@@ -846,7 +844,7 @@ private[spark] class DAGScheduler(
eventProcessLoop.post(AllJobsCancelled)
}
- private[scheduler] def doCancelAllJobs() {
+ private[scheduler] def doCancelAllJobs(): Unit = {
// Cancel all running jobs.
runningStages.map(_.firstJobId).foreach(handleJobCancellation(_,
Option("as part of cancellation of all jobs")))
@@ -857,7 +855,7 @@ private[spark] class DAGScheduler(
/**
* Cancel all jobs associated with a running or scheduled stage.
*/
- def cancelStage(stageId: Int, reason: Option[String]) {
+ def cancelStage(stageId: Int, reason: Option[String]): Unit = {
eventProcessLoop.post(StageCancelled(stageId, reason))
}
@@ -874,7 +872,7 @@ private[spark] class DAGScheduler(
* Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
* the last fetch failure.
*/
- private[scheduler] def resubmitFailedStages() {
+ private[scheduler] def resubmitFailedStages(): Unit = {
if (failedStages.nonEmpty) {
// Failed stages may be removed by job cancellation, so failed might be empty even if
// the ResubmitFailedStages event has been scheduled.
@@ -893,7 +891,7 @@ private[spark] class DAGScheduler(
* Submits stages that depend on the given parent stage. Called when the parent stage completes
* successfully.
*/
- private def submitWaitingChildStages(parent: Stage) {
+ private def submitWaitingChildStages(parent: Stage): Unit = {
logTrace(s"Checking if any dependencies of $parent are now runnable")
logTrace("running: " + runningStages)
logTrace("waiting: " + waitingStages)
@@ -915,7 +913,7 @@ private[spark] class DAGScheduler(
jobsThatUseStage.find(jobIdToActiveJob.contains)
}
- private[scheduler] def handleJobGroupCancelled(groupId: String) {
+ private[scheduler] def handleJobGroupCancelled(groupId: String): Unit = {
// Cancel all jobs belonging to this job group.
// First finds all active jobs with this group id, and then kill stages for them.
val activeInGroup = activeJobs.filter { activeJob =>
@@ -928,7 +926,7 @@ private[spark] class DAGScheduler(
Option("part of cancelled job group %s".format(groupId))))
}
- private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
+ private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = {
// Note that there is a chance that this task is launched after the stage is cancelled.
// In that case, we wouldn't have the stage anymore in stageIdToStage.
val stageAttemptId =
@@ -947,7 +945,7 @@ private[spark] class DAGScheduler(
stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) }
}
- private[scheduler] def cleanUpAfterSchedulerStop() {
+ private[scheduler] def cleanUpAfterSchedulerStop(): Unit = {
for (job <- activeJobs) {
val error =
new SparkException(s"Job ${job.jobId} cancelled because SparkContext was shut down")
@@ -965,7 +963,7 @@ private[spark] class DAGScheduler(
}
}
- private[scheduler] def handleGetTaskResult(taskInfo: TaskInfo) {
+ private[scheduler] def handleGetTaskResult(taskInfo: TaskInfo): Unit = {
listenerBus.post(SparkListenerTaskGettingResult(taskInfo))
}
@@ -975,7 +973,7 @@ private[spark] class DAGScheduler(
partitions: Array[Int],
callSite: CallSite,
listener: JobListener,
- properties: Properties) {
+ properties: Properties): Unit = {
var finalStage: ResultStage = null
try {
// New stage creation may throw an exception if, for example, jobs are run on a
@@ -1039,7 +1037,7 @@ private[spark] class DAGScheduler(
dependency: ShuffleDependency[_, _, _],
callSite: CallSite,
listener: JobListener,
- properties: Properties) {
+ properties: Properties): Unit = {
// Submitting this map stage might still require the creation of some parent stages, so make
// sure that happens.
var finalStage: ShuffleMapStage = null
@@ -1079,7 +1077,7 @@ private[spark] class DAGScheduler(
}
/** Submits stage, but first recursively submits any missing parents. */
- private def submitStage(stage: Stage) {
+ private def submitStage(stage: Stage): Unit = {
val jobId = activeJobForStage(stage)
if (jobId.isDefined) {
logDebug("submitStage(" + stage + ")")
@@ -1102,10 +1100,19 @@ private[spark] class DAGScheduler(
}
/** Called when stage's parents are available and we can now do its task. */
- private def submitMissingTasks(stage: Stage, jobId: Int) {
+ private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
logDebug("submitMissingTasks(" + stage + ")")
- // First figure out the indexes of partition ids to compute.
+ // Before finding the missing partitions, do the intermediate-state cleanup first.
+ // This makes sure that, for a partially completed intermediate stage,
+ // `findMissingPartitions()` returns all partitions every time.
+ stage match {
+ case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
+ mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId)
+ case _ =>
+ }
+
+ // Figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
// Use the scheduling pool, job group, description, etc. from an ActiveJob associated
@@ -1346,7 +1353,7 @@ private[spark] class DAGScheduler(
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
- private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
+ private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
val task = event.task
val stageId = task.stageId
@@ -1500,7 +1507,7 @@ private[spark] class DAGScheduler(
}
}
- case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) =>
+ case FetchFailed(bmAddress, shuffleId, _, mapIndex, _, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleIdToMapStage(shuffleId)
@@ -1523,17 +1530,17 @@ private[spark] class DAGScheduler(
markStageAsFinished(failedStage, errorMessage = Some(failureMessage),
willRetry = !shouldAbortStage)
} else {
- logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
- s"longer running")
+ logDebug(s"Received fetch failure from $task, but it's from $failedStage which is no " +
+ "longer running")
}
if (mapStage.rdd.isBarrier()) {
// Mark all the map as broken in the map stage, to ensure retry all the tasks on
// resubmitted stage attempt.
mapOutputTracker.unregisterAllMapOutput(shuffleId)
- } else if (mapId != -1) {
+ } else if (mapIndex != -1) {
// Mark the map whose fetch failed as broken in the map stage
- mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapIndex, bmAddress)
}
if (failedStage.rdd.isBarrier()) {
@@ -1575,7 +1582,7 @@ private[spark] class DAGScheduler(
// Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is
// guaranteed to be determinate, so the input data of the reducers will not change
// even if the map tasks are re-tried.
- if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
+ if (mapStage.isIndeterminate) {
// It's a little tricky to find all the succeeding stages of `mapStage`, because
// each stage only know its parents not children. Here we traverse the stages from
// the leaf nodes (the result stages of active jobs), and rollback all the stages
@@ -1603,15 +1610,22 @@ private[spark] class DAGScheduler(
activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))
+ // The stages that will be rolled back after the checks below
+ val rollingBackStages = HashSet[Stage](mapStage)
stagesToRollback.foreach {
case mapStage: ShuffleMapStage =>
val numMissingPartitions = mapStage.findMissingPartitions().length
if (numMissingPartitions < mapStage.numTasks) {
- // TODO: support to rollback shuffle files.
- // Currently the shuffle writing is "first write wins", so we can't re-run a
- // shuffle map stage and overwrite existing shuffle files. We have to finish
- // SPARK-8029 first.
- abortStage(mapStage, generateErrorMessage(mapStage), None)
+ if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
+ val reason = "A shuffle map stage with indeterminate output was failed " +
+ "and retried. However, Spark can only do this while using the new " +
+ "shuffle block fetching protocol. Please check the config " +
+ "'spark.shuffle.useOldFetchProtocol', see more detail in " +
+ "SPARK-27665 and SPARK-25341."
+ abortStage(mapStage, reason, None)
+ } else {
+ rollingBackStages += mapStage
+ }
}
case resultStage: ResultStage if resultStage.activeJob.isDefined =>
@@ -1623,6 +1637,9 @@ private[spark] class DAGScheduler(
case _ =>
}
+ logInfo(s"The shuffle map stage $mapStage with indeterminate output was failed, " +
+ s"we will roll back and rerun below stages which include itself and all its " +
+ s"indeterminate child stages: $rollingBackStages")
}
// We expect one executor failure to trigger many FetchFailures in rapid succession,
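
Whether the rollback path or the abort path above is taken hinges on one flag; a hedged configuration sketch (the flag's default is assumed here to be false, i.e. the new protocol):

    import org.apache.spark.SparkConf

    // Indeterminate shuffle map stages can only be rolled back and rerun while the
    // new shuffle block fetching protocol is in use, i.e. this flag stays off.
    val conf = new SparkConf()
      .setAppName("indeterminate-retry-demo")
      .set("spark.shuffle.useOldFetchProtocol", "false")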
@@ -1862,7 +1879,7 @@ private[spark] class DAGScheduler(
clearCacheLocs()
}
- private[scheduler] def handleExecutorAdded(execId: String, host: String) {
+ private[scheduler] def handleExecutorAdded(execId: String, host: String): Unit = {
// remove from failedEpoch(execId) ?
if (failedEpoch.contains(execId)) {
logInfo("Host added was in lost list earlier: " + host)
@@ -1870,7 +1887,7 @@ private[spark] class DAGScheduler(
}
}
- private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]) {
+ private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]): Unit = {
stageIdToStage.get(stageId) match {
case Some(stage) =>
val jobsThatUseStage: Array[Int] = stage.jobIds.toArray
@@ -1888,7 +1905,7 @@ private[spark] class DAGScheduler(
}
}
- private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]) {
+ private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]): Unit = {
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to cancel unregistered job " + jobId)
} else {
@@ -2010,7 +2027,7 @@ private[spark] class DAGScheduler(
// caused by recursively visiting
val waitingForVisit = new ListBuffer[RDD[_]]
waitingForVisit += stage.rdd
- def visit(rdd: RDD[_]) {
+ def visit(rdd: RDD[_]): Unit = {
if (!visitedRdds(rdd)) {
visitedRdds += rdd
for (dep <- rdd.dependencies) {
@@ -2103,7 +2120,7 @@ private[spark] class DAGScheduler(
listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
}
- def stop() {
+ def stop(): Unit = {
messageScheduler.shutdownNow()
eventProcessLoop.stop()
taskScheduler.stop()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 48eb2da3015f8..a0a4428dc7f55 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -67,7 +67,6 @@ private[spark] class EventLoggingListener(
private val shouldCompress = sparkConf.get(EVENT_LOG_COMPRESS)
private val shouldOverwrite = sparkConf.get(EVENT_LOG_OVERWRITE)
private val shouldLogBlockUpdates = sparkConf.get(EVENT_LOG_BLOCK_UPDATES)
- private val shouldAllowECLogs = sparkConf.get(EVENT_LOG_ALLOW_EC)
private val shouldLogStageExecutorMetrics = sparkConf.get(EVENT_LOG_STAGE_EXECUTOR_METRICS)
private val testing = sparkConf.get(EVENT_LOG_TESTING)
private val outputBufferSize = sparkConf.get(EVENT_LOG_OUTPUT_BUFFER_SIZE).toInt
@@ -100,7 +99,7 @@ private[spark] class EventLoggingListener(
/**
* Creates the log file in the configured log directory.
*/
- def start() {
+ def start(): Unit = {
if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDirectory) {
throw new IllegalArgumentException(s"Log directory $logBaseDir is not a directory.")
}
@@ -121,21 +120,19 @@ private[spark] class EventLoggingListener(
if ((isDefaultLocal && uri.getScheme == null) || uri.getScheme == "file") {
new FileOutputStream(uri.getPath)
} else {
- hadoopDataStream = Some(if (shouldAllowECLogs) {
- fileSystem.create(path)
- } else {
- SparkHadoopUtil.createNonECFile(fileSystem, path)
- })
+ hadoopDataStream = Some(
+ SparkHadoopUtil.createFile(fileSystem, path, sparkConf.get(EVENT_LOG_ALLOW_EC)))
hadoopDataStream.get
}
try {
- val cstream = compressionCodec.map(_.compressedOutputStream(dstream)).getOrElse(dstream)
+ val cstream = compressionCodec.map(_.compressedContinuousOutputStream(dstream))
+ .getOrElse(dstream)
val bstream = new BufferedOutputStream(cstream, outputBufferSize)
EventLoggingListener.initEventLog(bstream, testing, loggedEvents)
fileSystem.setPermission(path, LOG_FILE_PERMISSIONS)
- writer = Some(new PrintWriter(bstream))
+ writer = Some(new PrintWriter(new OutputStreamWriter(bstream, StandardCharsets.UTF_8)))
logInfo("Logging events to %s".format(logPath))
} catch {
case e: Exception =>
@@ -145,7 +142,7 @@ private[spark] class EventLoggingListener(
}
/** Log the event as JSON. */
- private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) {
+ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false): Unit = {
val eventJson = JsonProtocol.sparkEventToJson(event)
// scalastyle:off println
writer.foreach(_.println(compact(render(eventJson))))
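
A small writer-side sketch of the charset fix above, outside of Spark (the file path is a placeholder):

    import java.io.{BufferedOutputStream, FileOutputStream, OutputStreamWriter, PrintWriter}
    import java.nio.charset.StandardCharsets

    // Pinning the writer to UTF-8 keeps event logs readable regardless of the
    // driver JVM's default charset; the old PrintWriter(bstream) used the default.
    val out = new BufferedOutputStream(new FileOutputStream("/tmp/events.log"))
    val writer = new PrintWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
    writer.println("""{"Event":"SparkListenerLogStart","Spark Version":"3.0.0"}""")
    writer.close()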
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 66ab9a52b7781..2d26a314e7a62 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -64,7 +64,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
case _ => false
}
- private def validate() {
+ private def validate(): Unit = {
logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path)
try {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 65d7184231e24..feed831620840 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -49,7 +49,7 @@ private[spark] class JobWaiter[T](
* asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it
* will fail this job with a SparkException.
*/
- def cancel() {
+ def cancel(): Unit = {
dagScheduler.cancelJob(jobId, None)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
index 302ebd30da228..bbbddd86cad39 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -186,6 +186,17 @@ private[spark] class LiveListenerBus(conf: SparkConf) {
metricsSystem.registerSource(metrics)
}
+ /**
+ * For testing only. Wait until there are no more events in the queue, or until the default
+ * wait time has elapsed. Throw `TimeoutException` if the default wait time elapsed before
+ * the queue emptied.
+ */
+ @throws(classOf[TimeoutException])
+ private[spark] def waitUntilEmpty(): Unit = {
+ waitUntilEmpty(TimeUnit.SECONDS.toMillis(10))
+ }
+
/**
* For testing only. Wait until there are no more events in the queue, or until the specified
* time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue
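
A hedged test-style usage sketch of the new no-argument overload. listenerBus is private[spark], so this only compiles inside Spark's own test sources, and the counting listener is illustrative:

    // Run a job, then block until the bus drains or the default 10s timeout fires.
    sc.parallelize(1 to 10).count()
    sc.listenerBus.waitUntilEmpty() // was: sc.listenerBus.waitUntilEmpty(10000)
    assert(jobEndCounter.get() == 1) // hypothetical counter bumped by a test listener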
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 64f0a060a247c..c9d37c985d211 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -43,6 +43,11 @@ private[spark] sealed trait MapStatus {
* necessary for correctness, since block fetchers are allowed to skip zero-size blocks.
*/
def getSizeForBlock(reduceId: Int): Long
+
+ /**
+ * The unique ID of this shuffle map task; it is filled with taskContext.taskAttemptId.
+ */
+ def mapTaskId: Long
}
@@ -56,11 +61,14 @@ private[spark] object MapStatus {
.map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
.getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)
- def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
+ def apply(
+ loc: BlockManagerId,
+ uncompressedSizes: Array[Long],
+ mapTaskId: Long): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
- HighlyCompressedMapStatus(loc, uncompressedSizes)
+ HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId)
} else {
- new CompressedMapStatus(loc, uncompressedSizes)
+ new CompressedMapStatus(loc, uncompressedSizes, mapTaskId)
}
}
@@ -100,16 +108,19 @@ private[spark] object MapStatus {
*
* @param loc location where the task is being executed.
* @param compressedSizes size of the blocks, indexed by reduce partition id.
+ * @param _mapTaskId unique task attempt id of the map task
*/
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
- private[this] var compressedSizes: Array[Byte])
+ private[this] var compressedSizes: Array[Byte],
+ private[this] var _mapTaskId: Long)
extends MapStatus with Externalizable {
- protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
+ // For deserialization only
+ protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1)
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
- this(loc, uncompressedSizes.map(MapStatus.compressSize))
+ def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long) {
+ this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId)
}
override def location: BlockManagerId = loc
@@ -118,10 +129,13 @@ private[spark] class CompressedMapStatus(
MapStatus.decompressSize(compressedSizes(reduceId))
}
+ override def mapTaskId: Long = _mapTaskId
+
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
+ out.writeLong(_mapTaskId)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -129,6 +143,7 @@ private[spark] class CompressedMapStatus(
val len = in.readInt()
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
+ _mapTaskId = in.readLong()
}
}
@@ -142,20 +157,23 @@ private[spark] class CompressedMapStatus(
* @param emptyBlocks a bitmap tracking which blocks are empty
* @param avgSize average size of the non-empty and non-huge blocks
* @param hugeBlockSizes sizes of huge blocks by their reduceId.
+ * @param _mapTaskId unique task attempt id of the map task
*/
private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
private[this] var numNonEmptyBlocks: Int,
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
- private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte])
+ private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte],
+ private[this] var _mapTaskId: Long)
extends MapStatus with Externalizable {
// loc could be null when the default constructor is called during deserialization
- require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
+ require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0
+ || numNonEmptyBlocks == 0 || _mapTaskId > 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1, null) // For deserialization only
+ protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only
override def location: BlockManagerId = loc
@@ -171,6 +189,8 @@ private[spark] class HighlyCompressedMapStatus private (
}
}
+ override def mapTaskId: Long = _mapTaskId
+
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
emptyBlocks.writeExternal(out)
@@ -180,6 +200,7 @@ private[spark] class HighlyCompressedMapStatus private (
out.writeInt(kv._1)
out.writeByte(kv._2)
}
+ out.writeLong(_mapTaskId)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -195,11 +216,15 @@ private[spark] class HighlyCompressedMapStatus private (
hugeBlockSizesImpl(block) = size
}
hugeBlockSizes = hugeBlockSizesImpl
+ _mapTaskId = in.readLong()
}
}
private[spark] object HighlyCompressedMapStatus {
- def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ def apply(
+ loc: BlockManagerId,
+ uncompressedSizes: Array[Long],
+ mapTaskId: Long): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -240,6 +265,6 @@ private[spark] object HighlyCompressedMapStatus {
emptyBlocks.trim()
emptyBlocks.runOptimize()
new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
- hugeBlockSizes)
+ hugeBlockSizes, mapTaskId)
}
}
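
A hedged construction sketch showing where the new mapTaskId argument comes from. MapStatus is private[spark], so this is only meaningful inside a shuffle writer, and partitionLengths is a placeholder for the per-reducer block sizes:

    // The shuffle writer now reports the task attempt id alongside block sizes, so
    // the reduce side can tell which map task attempt produced a given block.
    val status = MapStatus(
      SparkEnv.get.blockManager.shuffleServerId, // where the blocks live
      partitionLengths,                          // uncompressed size per reduce id
      TaskContext.get().taskAttemptId())         // new: unique id of this attempt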
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index f4b0ab10155a2..80805df256a15 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -59,14 +59,14 @@ private[spark] class Pool(
}
}
- override def addSchedulable(schedulable: Schedulable) {
+ override def addSchedulable(schedulable: Schedulable): Unit = {
require(schedulable != null)
schedulableQueue.add(schedulable)
schedulableNameToSchedulable.put(schedulable.name, schedulable)
schedulable.parent = this
}
- override def removeSchedulable(schedulable: Schedulable) {
+ override def removeSchedulable(schedulable: Schedulable): Unit = {
schedulableQueue.remove(schedulable)
schedulableNameToSchedulable.remove(schedulable.name)
}
@@ -84,7 +84,7 @@ private[spark] class Pool(
null
}
- override def executorLost(executorId: String, host: String, reason: ExecutorLossReason) {
+ override def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit = {
schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason))
}
@@ -106,14 +106,14 @@ private[spark] class Pool(
sortedTaskSetQueue
}
- def increaseRunningTasks(taskNum: Int) {
+ def increaseRunningTasks(taskNum: Int): Unit = {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
- def decreaseRunningTasks(taskNum: Int) {
+ def decreaseRunningTasks(taskNum: Int): Unit = {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
index 226c23733c870..699042dd967bc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import java.io.{EOFException, InputStream, IOException}
-import scala.io.Source
+import scala.io.{Codec, Source}
import com.fasterxml.jackson.core.JsonParseException
import com.fasterxml.jackson.databind.exc.UnrecognizedPropertyException
@@ -54,7 +54,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
sourceName: String,
maybeTruncated: Boolean = false,
eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = {
- val lines = Source.fromInputStream(logData).getLines()
+ val lines = Source.fromInputStream(logData)(Codec.UTF8).getLines()
replay(lines, sourceName, maybeTruncated, eventsFilter)
}
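
A standalone reader-side sketch matching the change above (the event-log path is a placeholder):

    import java.io.FileInputStream

    import scala.io.{Codec, Source}

    // Read an event log with an explicit UTF-8 codec so replay does not depend on
    // the JVM default charset and matches the UTF-8 writer on the logging side.
    val in = new FileInputStream("/tmp/spark-events/app-20190101000000-0001")
    val lines = Source.fromInputStream(in)(Codec.UTF8).getLines()
    lines.take(3).foreach(println)
    in.close()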
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index c85c74f2fb973..8f6a22177a5b8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -45,11 +45,11 @@ private[spark] trait SchedulableBuilder {
private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
extends SchedulableBuilder with Logging {
- override def buildPools() {
+ override def buildPools(): Unit = {
// nothing
}
- override def addTaskSetManager(manager: Schedulable, properties: Properties) {
+ override def addTaskSetManager(manager: Schedulable, properties: Properties): Unit = {
rootPool.addSchedulable(manager)
}
}
@@ -70,7 +70,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
val DEFAULT_MINIMUM_SHARE = 0
val DEFAULT_WEIGHT = 1
- override def buildPools() {
+ override def buildPools(): Unit = {
var fileData: Option[(InputStream, String)] = None
try {
fileData = schedulerAllocFile.map { f =>
@@ -106,7 +106,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
buildDefaultPool()
}
- private def buildDefaultPool() {
+ private def buildDefaultPool(): Unit = {
if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) {
val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE,
DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
@@ -116,7 +116,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
}
}
- private def buildFairSchedulerPool(is: InputStream, fileName: String) {
+ private def buildFairSchedulerPool(is: InputStream, fileName: String): Unit = {
val xml = XML.load(is)
for (poolNode <- (xml \\ POOLS_PROPERTY)) {
@@ -180,7 +180,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
}
}
- override def addTaskSetManager(manager: Schedulable, properties: Properties) {
+ override def addTaskSetManager(manager: Schedulable, properties: Properties): Unit = {
val poolName = if (properties != null) {
properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME)
} else {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 710f5eb211dde..06e5d8ab0302a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -91,7 +91,7 @@ private[spark] class ShuffleMapTask(
val rdd = rddAndDep._1
val dep = rddAndDep._2
- dep.shuffleWriterProcessor.write(rdd, dep, partitionId, context, partition)
+ dep.shuffleWriterProcessor.write(rdd, dep, context, partition)
}
override def preferredLocations: Seq[TaskLocation] = preferredLocs
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 26cca334d3bd5..a9f72eae71368 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.HashSet
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
import org.apache.spark.util.CallSite
/**
@@ -116,4 +116,8 @@ private[scheduler] abstract class Stage(
/** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
def findMissingPartitions(): Seq[Int]
+
+ def isIndeterminate: Boolean = {
+ rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index e3216151462bd..fdc50328b43d8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -52,7 +52,7 @@ class StageInfo(
*/
val accumulables = HashMap[Long, AccumulableInfo]()
- def stageFailed(reason: String) {
+ def stageFailed(reason: String): Unit = {
failureReason = Some(reason)
completionTime = Some(System.currentTimeMillis)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
index 3c7af4f6146fa..ca48775e77f27 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
@@ -36,7 +36,7 @@ class StatsReportListener extends SparkListener with Logging {
private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val info = taskEnd.taskInfo
val metrics = taskEnd.taskMetrics
if (info != null && metrics != null) {
@@ -44,7 +44,7 @@ class StatsReportListener extends SparkListener with Logging {
}
}
- override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
implicit val sc = stageCompleted
this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}")
showMillisDistribution("task runtime:", (info, _) => info.duration, taskInfoMetrics)
@@ -108,7 +108,7 @@ private[spark] object StatsReportListener extends Logging {
(info, metric) => { getMetric(info, metric).toDouble })
}
- def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
+ def showDistribution(heading: String, d: Distribution, formatNumber: Double => String): Unit = {
val stats = d.statCounter
val quantiles = d.getQuantiles(probabilities).map(formatNumber)
logInfo(heading + stats)
@@ -119,11 +119,11 @@ private[spark] object StatsReportListener extends Logging {
def showDistribution(
heading: String,
dOpt: Option[Distribution],
- formatNumber: Double => String) {
+ formatNumber: Double => String): Unit = {
dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
}
- def showDistribution(heading: String, dOpt: Option[Distribution], format: String) {
+ def showDistribution(heading: String, dOpt: Option[Distribution], format: String): Unit = {
def f(d: Double): String = format.format(d)
showDistribution(heading, dOpt, f _)
}
@@ -132,26 +132,26 @@ private[spark] object StatsReportListener extends Logging {
heading: String,
format: String,
getMetric: (TaskInfo, TaskMetrics) => Double,
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]): Unit = {
showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format)
}
def showBytesDistribution(
heading: String,
getMetric: (TaskInfo, TaskMetrics) => Long,
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]): Unit = {
showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
}
- def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
+ def showBytesDistribution(heading: String, dOpt: Option[Distribution]): Unit = {
dOpt.foreach { dist => showBytesDistribution(heading, dist) }
}
- def showBytesDistribution(heading: String, dist: Distribution) {
+ def showBytesDistribution(heading: String, dist: Distribution): Unit = {
showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String)
}
- def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
+ def showMillisDistribution(heading: String, dOpt: Option[Distribution]): Unit = {
showDistribution(heading, dOpt,
(d => StatsReportListener.millisToString(d.toLong)): Double => String)
}
@@ -159,7 +159,7 @@ private[spark] object StatsReportListener extends Logging {
def showMillisDistribution(
heading: String,
getMetric: (TaskInfo, TaskMetrics) => Long,
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]): Unit = {
showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 01828f860bd5e..ebc1c05435fee 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -225,7 +225,7 @@ private[spark] abstract class Task[T](
* be called multiple times.
* If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread.
*/
- def kill(interruptThread: Boolean, reason: String) {
+ def kill(interruptThread: Boolean, reason: String): Unit = {
require(reason != null)
_reasonIfKilled = reason
if (context != null) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 9843eab4f1346..921562bd15dae 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -70,11 +70,11 @@ class TaskInfo(
var killed = false
- private[spark] def markGettingResult(time: Long) {
+ private[spark] def markGettingResult(time: Long): Unit = {
gettingResultTime = time
}
- private[spark] def markFinished(state: TaskState, time: Long) {
+ private[spark] def markFinished(state: TaskState, time: Long): Unit = {
// finishTime should be set larger than 0, otherwise "finished" below will return false.
assert(time > 0)
finishTime = time
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 9b7f901c55e00..6c3d2a4ee3125 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -64,6 +64,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
case directResult: DirectTaskResult[_] =>
if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
+ // kill the task so that it will not become a zombie task
+ scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
+ "Task result size has exceeded maxResultSize"))
return
}
// deserialize "value" without holding any lock so that it won't block other threads.
@@ -75,6 +78,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
if (!taskSetManager.canFetchMoreResults(size)) {
// dropped by executor if size is larger than maxResultSize
sparkEnv.blockManager.master.removeBlock(blockId)
+ // kill the task so that it will not become a zombie task
+ scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
+ "Task result size has exceeded maxResultSize"))
return
}
logDebug("Fetching indirect task result for TID %s".format(tid))
@@ -125,7 +131,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
}
def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
- serializedData: ByteBuffer) {
+ serializedData: ByteBuffer): Unit = {
var reason : TaskFailedReason = UnknownReason
try {
getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
@@ -164,7 +170,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
})
}
- def stop() {
+ def stop(): Unit = {
getTaskResultExecutor.shutdownNow()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 8c73d563043c2..15f5d20e9be75 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -46,7 +46,7 @@ private[spark] trait TaskScheduler {
// Invoked after system has successfully initialized (typically in spark context).
// Yarn uses this to bootstrap allocation of resources based on preferred locations,
// wait for slave registrations, etc.
- def postStartHook() { }
+ def postStartHook(): Unit = { }
// Disconnect from the cluster.
def stop(): Unit
@@ -72,7 +72,7 @@ private[spark] trait TaskScheduler {
// Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed
// and they can skip running tasks for it.
- def notifyPartitionCompletion(stageId: Int, partitionId: Int)
+ def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit
// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 1496dff31a4dc..f25a36c7af22a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -170,11 +170,11 @@ private[spark] class TaskSchedulerImpl(
}
}
- override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {
this.dagScheduler = dagScheduler
}
- def initialize(backend: SchedulerBackend) {
+ def initialize(backend: SchedulerBackend): Unit = {
this.backend = backend
schedulableBuilder = {
schedulingMode match {
@@ -192,7 +192,7 @@ private[spark] class TaskSchedulerImpl(
def newTaskId(): Long = nextTaskId.getAndIncrement()
- override def start() {
+ override def start(): Unit = {
backend.start()
if (!isLocal && conf.get(SPECULATION_ENABLED)) {
@@ -203,11 +203,11 @@ private[spark] class TaskSchedulerImpl(
}
}
- override def postStartHook() {
+ override def postStartHook(): Unit = {
waitBackendReady()
}
- override def submitTasks(taskSet: TaskSet) {
+ override def submitTasks(taskSet: TaskSet): Unit = {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
@@ -233,7 +233,7 @@ private[spark] class TaskSchedulerImpl(
if (!isLocal && !hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
- override def run() {
+ override def run(): Unit = {
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
@@ -430,7 +430,6 @@ private[spark] class TaskSchedulerImpl(
val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK))
val availableResources = shuffledOffers.map(_.resources).toArray
val availableCpus = shuffledOffers.map(o => o.cores).toArray
- val availableSlots = shuffledOffers.map(o => o.cores / CPUS_PER_TASK).sum
val sortedTaskSets = rootPool.getSortedTaskSetQueue
for (taskSet <- sortedTaskSets) {
logDebug("parentName: %s, name: %s, runningTasks: %s".format(
@@ -444,6 +443,7 @@ private[spark] class TaskSchedulerImpl(
// of locality levels so that it gets a chance to launch local tasks on all of them.
// NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY
for (taskSet <- sortedTaskSets) {
+ val availableSlots = availableCpus.map(c => c / CPUS_PER_TASK).sum
// Skip the barrier taskSet if the available slots are less than the number of pending tasks.
if (taskSet.isBarrier && availableSlots < taskSet.numTasks) {
// Skip the launch process.
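
A tiny arithmetic sketch of the barrier slot check, with made-up offer sizes:

    // Three live offers with 8, 8 and 4 free cores and spark.task.cpus = 2 give
    // 8/2 + 8/2 + 4/2 = 10 slots, so a barrier stage of 11 tasks is skipped this
    // round. Recomputing availableSlots inside the loop keeps the count current
    // as earlier task sets consume cores.
    val availableCpus = Array(8, 8, 4)
    val cpusPerTask = 2
    val availableSlots = availableCpus.map(c => c / cpusPerTask).sum // 10
    val skipBarrierTaskSet = 11 > availableSlots                     // true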
@@ -572,7 +572,7 @@ private[spark] class TaskSchedulerImpl(
Random.shuffle(offers)
}
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer): Unit = {
var failedExecutor: Option[String] = None
var reason: Option[ExecutorLossReason] = None
synchronized {
@@ -681,7 +681,7 @@ private[spark] class TaskSchedulerImpl(
})
}
- def error(message: String) {
+ def error(message: String): Unit = {
synchronized {
if (taskSetsByStageIdAndAttempt.nonEmpty) {
// Have each task set throw a SparkException with the error
@@ -704,7 +704,7 @@ private[spark] class TaskSchedulerImpl(
}
}
- override def stop() {
+ override def stop(): Unit = {
speculationScheduler.shutdown()
if (backend != null) {
backend.stop()
@@ -722,7 +722,7 @@ private[spark] class TaskSchedulerImpl(
override def defaultParallelism(): Int = backend.defaultParallelism()
// Check for speculatable tasks in all our active jobs.
- def checkSpeculatableTasks() {
+ def checkSpeculatableTasks(): Unit = {
var shouldRevive = false
synchronized {
shouldRevive = rootPool.checkSpeculatableTasks(MIN_TIME_TO_SPECULATION)
@@ -798,7 +798,7 @@ private[spark] class TaskSchedulerImpl(
* reason is not yet known, do not yet remove its association with its host nor update the status
* of any running tasks, since the loss reason defines whether we'll fail those tasks.
*/
- private def removeExecutor(executorId: String, reason: ExecutorLossReason) {
+ private def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = {
// The tasks on the lost executor may not send any more status updates (because the executor
// has been lost), so they should be cleaned up here.
executorIdToRunningTaskIds.remove(executorId).foreach { taskIds =>
@@ -829,7 +829,7 @@ private[spark] class TaskSchedulerImpl(
blacklistTrackerOpt.foreach(_.handleRemovedExecutor(executorId))
}
- def executorAdded(execId: String, host: String) {
+ def executorAdded(execId: String, host: String): Unit = {
dagScheduler.executorAdded(execId, host)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala
index b680979a466a5..4df2889089ee9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala
@@ -69,7 +69,6 @@ private[scheduler] class TaskSetBlacklist(
/**
* Get the most recent failure reason of this TaskSet.
- * @return
*/
def getLatestFailureReason: String = {
latestFailureReason
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 49bd55e553482..9defbefabb86a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -474,7 +474,7 @@ private[spark] class TaskSetManager(
}
}
- private def maybeFinishTaskSet() {
+ private def maybeFinishTaskSet(): Unit = {
if (isZombie && runningTasks == 0) {
sched.taskSetFinished(this)
if (tasksSuccessful == numTasks) {
@@ -758,7 +758,7 @@ private[spark] class TaskSetManager(
* Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
* DAG Scheduler.
*/
- def handleFailedTask(tid: Long, state: TaskState, reason: TaskFailedReason) {
+ def handleFailedTask(tid: Long, state: TaskState, reason: TaskFailedReason): Unit = {
val info = taskInfos(tid)
if (info.failed || info.killed) {
return
@@ -886,14 +886,14 @@ private[spark] class TaskSetManager(
*
* Used to keep track of the number of running tasks, for enforcing scheduling policies.
*/
- def addRunningTask(tid: Long) {
+ def addRunningTask(tid: Long): Unit = {
if (runningTasksSet.add(tid) && parent != null) {
parent.increaseRunningTasks(1)
}
}
/** If the given task ID is in the set of running tasks, removes it. */
- def removeRunningTask(tid: Long) {
+ def removeRunningTask(tid: Long): Unit = {
if (runningTasksSet.remove(tid) && parent != null) {
parent.decreaseRunningTasks(1)
}
@@ -903,9 +903,9 @@ private[spark] class TaskSetManager(
null
}
- override def addSchedulable(schedulable: Schedulable) {}
+ override def addSchedulable(schedulable: Schedulable): Unit = {}
- override def removeSchedulable(schedulable: Schedulable) {}
+ override def removeSchedulable(schedulable: Schedulable): Unit = {}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
val sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]()
@@ -914,7 +914,7 @@ private[spark] class TaskSetManager(
}
/** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
- override def executorLost(execId: String, host: String, reason: ExecutorLossReason) {
+ override def executorLost(execId: String, host: String, reason: ExecutorLossReason): Unit = {
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage,
// and we are not using an external shuffle server which could serve the shuffle outputs.
// The reason is the next stage wouldn't be able to fetch the data from this dead executor
@@ -1035,14 +1035,14 @@ private[spark] class TaskSetManager(
levels.toArray
}
- def recomputeLocality() {
+ def recomputeLocality(): Unit = {
val previousLocalityLevel = myLocalityLevels(currentLocalityIndex)
myLocalityLevels = computeValidLocalityLevels()
localityWaits = myLocalityLevels.map(getLocalityWait)
currentLocalityIndex = getLocalityIndex(previousLocalityLevel)
}
- def executorAdded() {
+ def executorAdded(): Unit = {
recomputeLocality()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index d81070c362ba6..4958389ae4257 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -68,10 +68,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
conf.get(SCHEDULER_MAX_REGISTERED_RESOURCE_WAITING_TIME))
private val createTimeNs = System.nanoTime()
- // Accessing `executorDataMap` in `DriverEndpoint.receive/receiveAndReply` doesn't need any
- // protection. But accessing `executorDataMap` out of `DriverEndpoint.receive/receiveAndReply`
- // must be protected by `CoarseGrainedSchedulerBackend.this`. Besides, `executorDataMap` should
- // only be modified in `DriverEndpoint.receive/receiveAndReply` with protection by
+ // Accessing `executorDataMap` in the inherited methods from ThreadSafeRpcEndpoint doesn't need
+ // any protection. But accessing `executorDataMap` out of the inherited methods must be
+ // protected by `CoarseGrainedSchedulerBackend.this`. Besides, `executorDataMap` should only
+ // be modified in the inherited methods from ThreadSafeRpcEndpoint with protection by
// `CoarseGrainedSchedulerBackend.this`.
private val executorDataMap = new HashMap[String, ExecutorData]
@@ -129,7 +129,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
private val logUrlHandler: ExecutorLogUrlHandler = new ExecutorLogUrlHandler(
conf.get(UI.CUSTOM_EXECUTOR_LOG_URL))
- override def onStart() {
+ override def onStart(): Unit = {
// Periodically revive offers to allow delay scheduling to work
val reviveIntervalMs = conf.get(SCHEDULER_REVIVE_INTERVAL).getOrElse(1000L)
@@ -263,7 +263,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
// Make fake resource offers on all executors
- private def makeOffers() {
+ private def makeOffers(): Unit = {
// Make sure no executor is killed while some task is launching on it
val taskDescs = withLock {
// Filter out executors under killing
@@ -292,7 +292,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
// Make fake resource offers on just one executor
- private def makeOffers(executorId: String) {
+ private def makeOffers(executorId: String): Unit = {
// Make sure no executor is killed while some task is launching on it
val taskDescs = withLock {
// Filter out executors under killing
@@ -320,7 +320,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
// Launch tasks returned by a set of resource offers
- private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
+ private def launchTasks(tasks: Seq[Seq[TaskDescription]]): Unit = {
for (task <- tasks.flatten) {
val serializedTask = TaskDescription.encode(task)
if (serializedTask.limit() >= maxRpcMessageSize) {
@@ -420,19 +420,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
protected def minRegisteredRatio: Double = _minRegisteredRatio
- override def start() {
+ override def start(): Unit = {
if (UserGroupInformation.isSecurityEnabled()) {
delegationTokenManager = createTokenManager()
delegationTokenManager.foreach { dtm =>
val ugi = UserGroupInformation.getCurrentUser()
val tokens = if (dtm.renewalEnabled) {
dtm.start()
- } else if (ugi.hasKerberosCredentials() || SparkHadoopUtil.get.isProxyUser(ugi)) {
+ } else {
val creds = ugi.getCredentials()
dtm.obtainDelegationTokens(creds)
- SparkHadoopUtil.get.serialize(creds)
- } else {
- null
+ if (creds.numberOfTokens() > 0 || creds.numberOfSecretKeys() > 0) {
+ SparkHadoopUtil.get.serialize(creds)
+ } else {
+ null
+ }
}
if (tokens != null) {
updateDelegationTokens(tokens)
@@ -443,7 +445,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
protected def createDriverEndpoint(): DriverEndpoint = new DriverEndpoint()
- def stopExecutors() {
+ def stopExecutors(): Unit = {
try {
if (driverEndpoint != null) {
logInfo("Shutting down all executors")
@@ -455,7 +457,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
}
- override def stop() {
+ override def stop(): Unit = {
reviveThread.shutdownNow()
stopExecutors()
delegationTokenManager.foreach(_.stop())
@@ -488,12 +490,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
}
- override def reviveOffers() {
+ override def reviveOffers(): Unit = {
driverEndpoint.send(ReviveOffers)
}
override def killTask(
- taskId: Long, executorId: String, interruptThread: Boolean, reason: String) {
+ taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = {
driverEndpoint.send(KillTask(taskId, executorId, interruptThread, reason))
}
@@ -533,9 +535,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
/**
* Return the number of executors currently registered with this backend.
*/
- private def numExistingExecutors: Int = executorDataMap.size
+ private def numExistingExecutors: Int = synchronized { executorDataMap.size }
- override def getExecutorIds(): Seq[String] = {
+ override def getExecutorIds(): Seq[String] = synchronized {
executorDataMap.keySet.toSeq
}
@@ -543,14 +545,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
executorDataMap.contains(id) && !executorsPendingToRemove.contains(id)
}
- override def maxNumConcurrentTasks(): Int = {
+ override def maxNumConcurrentTasks(): Int = synchronized {
executorDataMap.values.map { executor =>
executor.totalCores / scheduler.CPUS_PER_TASK
}.sum
}
// this function is for testing only
- def getExecutorAvailableResources(executorId: String): Map[String, ExecutorResourceInfo] = {
+ def getExecutorAvailableResources(
+ executorId: String): Map[String, ExecutorResourceInfo] = synchronized {
executorDataMap.get(executorId).map(_.resourcesInfo).getOrElse(Map.empty)
}
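
The comment above spells out the locking rule that the new `synchronized` wrappers enforce: reads inside the endpoint's message-handling methods need no lock, while any access from other threads must hold the backend's monitor. A minimal sketch of that rule, with illustrative names only:

```scala
import scala.collection.mutable.HashMap

class BackendSketch {
  // Written only from the (single-threaded) endpoint, but always under this object's
  // lock so that readers on other threads can synchronize against the same monitor.
  private val executorDataMap = new HashMap[String, Int]()

  private[this] def registerExecutor(id: String, cores: Int): Unit = synchronized {
    executorDataMap(id) = cores
  }

  // Called from arbitrary threads, so it must take the lock.
  def numExistingExecutors: Int = synchronized { executorDataMap.size }

  def executorIds: Seq[String] = synchronized { executorDataMap.keySet.toSeq }
}
```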
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index 2025a7dc24821..a9b607d8cc38c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -59,7 +59,7 @@ private[spark] class StandaloneSchedulerBackend(
private val maxCores = conf.get(config.CORES_MAX)
private val totalExpectedCores = maxCores.getOrElse(0)
- override def start() {
+ override def start(): Unit = {
super.start()
// SPARK-21159. The scheduler backend should only try to connect to the launcher when in client
@@ -129,21 +129,21 @@ private[spark] class StandaloneSchedulerBackend(
stop(SparkAppHandle.State.FINISHED)
}
- override def connected(appId: String) {
+ override def connected(appId: String): Unit = {
logInfo("Connected to Spark cluster with app ID " + appId)
this.appId = appId
notifyContext()
launcherBackend.setAppId(appId)
}
- override def disconnected() {
+ override def disconnected(): Unit = {
notifyContext()
if (!stopping.get) {
logWarning("Disconnected from Spark cluster! Waiting for reconnection...")
}
}
- override def dead(reason: String) {
+ override def dead(reason: String): Unit = {
notifyContext()
if (!stopping.get) {
launcherBackend.setState(SparkAppHandle.State.KILLED)
@@ -158,13 +158,13 @@ private[spark] class StandaloneSchedulerBackend(
}
override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int,
- memory: Int) {
+ memory: Int): Unit = {
logInfo("Granted executor ID %s on hostPort %s with %d core(s), %s RAM".format(
fullId, hostPort, cores, Utils.megabytesToString(memory)))
}
override def executorRemoved(
- fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean) {
+ fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit = {
val reason: ExecutorLossReason = exitStatus match {
case Some(code) => ExecutorExited(code, exitCausedByApp = true, message)
case None => SlaveLost(message, workerLost = workerLost)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
index cbcc5310a59f0..d2c0dc88d987e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala
@@ -29,6 +29,7 @@ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.ExecutorInfo
+import org.apache.spark.util.Utils
private case class ReviveOffers()
@@ -54,7 +55,7 @@ private[spark] class LocalEndpoint(
private var freeCores = totalCores
val localExecutorId = SparkContext.DRIVER_IDENTIFIER
- val localExecutorHostname = "localhost"
+ val localExecutorHostname = Utils.localCanonicalHostName()
private val executor = new Executor(
localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true)
@@ -80,7 +81,7 @@ private[spark] class LocalEndpoint(
context.reply(true)
}
- def reviveOffers() {
+ def reviveOffers(): Unit = {
// local mode doesn't support extra resources like GPUs right now
val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores,
Some(rpcEnv.address.hostPort)))
@@ -123,7 +124,7 @@ private[spark] class LocalSchedulerBackend(
launcherBackend.connect()
- override def start() {
+ override def start(): Unit = {
val rpcEnv = SparkEnv.get.rpcEnv
val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores)
localEndpoint = rpcEnv.setupEndpoint("LocalSchedulerBackendEndpoint", executorEndpoint)
@@ -136,11 +137,11 @@ private[spark] class LocalSchedulerBackend(
launcherBackend.setState(SparkAppHandle.State.RUNNING)
}
- override def stop() {
+ override def stop(): Unit = {
stop(SparkAppHandle.State.FINISHED)
}
- override def reviveOffers() {
+ override def reviveOffers(): Unit = {
localEndpoint.send(ReviveOffers)
}
@@ -148,11 +149,11 @@ private[spark] class LocalSchedulerBackend(
scheduler.conf.getInt("spark.default.parallelism", totalCores)
override def killTask(
- taskId: Long, executorId: String, interruptThread: Boolean, reason: String) {
+ taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = {
localEndpoint.send(KillTask(taskId, interruptThread, reason))
}
- override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+ override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer): Unit = {
localEndpoint.send(StatusUpdate(taskId, state, serializedData))
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 70564eeefda88..077b035f3d079 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -54,8 +54,8 @@ private[spark] class JavaSerializationStream(
this
}
- def flush() { objOut.flush() }
- def close() { objOut.close() }
+ def flush(): Unit = { objOut.flush() }
+ def close(): Unit = { objOut.close() }
}
private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
@@ -74,7 +74,7 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa
}
def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
- def close() { objIn.close() }
+ def close(): Unit = { objIn.close() }
}
private object JavaDeserializationStream {
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 20774c8d999c1..6efb8b35733ef 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -259,14 +259,14 @@ class KryoSerializationStream(
this
}
- override def flush() {
+ override def flush(): Unit = {
if (output == null) {
throw new IOException("Stream is closed")
}
output.flush()
}
- override def close() {
+ override def close(): Unit = {
if (output != null) {
try {
output.close()
@@ -301,7 +301,7 @@ class KryoDeserializationStream(
}
}
- override def close() {
+ override def close(): Unit = {
if (input != null) {
try {
// Kryo's Input automatically closes the input stream it is using.
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index 5e7a98c8aa89c..75dc3982ab872 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -303,7 +303,7 @@ private[spark] object SerializationDebugger extends Logging {
/** An output stream that emulates /dev/null */
private class NullOutputStream extends OutputStream {
- override def write(b: Int) { }
+ override def write(b: Int): Unit = { }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index cb8b1cc077637..0c53a84af6e2f 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -173,7 +173,7 @@ abstract class DeserializationStream extends Closeable {
}
}
- override protected def close() {
+ override protected def close(): Unit = {
DeserializationStream.this.close()
}
}
@@ -193,7 +193,7 @@ abstract class DeserializationStream extends Closeable {
}
}
- override protected def close() {
+ override protected def close(): Unit = {
DeserializationStream.this.close()
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
index 04e4cf88d7063..6fe183c078089 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
@@ -24,6 +24,5 @@ import org.apache.spark.ShuffleDependency
*/
private[spark] class BaseShuffleHandle[K, V, C](
shuffleId: Int,
- val numMaps: Int,
val dependency: ShuffleDependency[K, V, C])
extends ShuffleHandle(shuffleId)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 4329824b1b627..8a0e84d901c2f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -47,7 +47,8 @@ private[spark] class BlockStoreShuffleReader[K, C](
context,
blockManager.blockStoreClient,
blockManager,
- mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+ mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition,
+ SparkEnv.get.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)),
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
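
For reference, the flag threaded into `getMapSizesByExecutorId` here corresponds to the `spark.shuffle.useOldFetchProtocol` setting; the key name is inferred from the constant and the sketch below is illustrative only:

```scala
import org.apache.spark.SparkConf

object OldFetchProtocolExample {
  // Illustrative only: enable the compatibility flag when the cluster still runs
  // an older external shuffle service. App name and value are placeholders.
  val conf: SparkConf = new SparkConf()
    .setAppName("shuffle-protocol-example")
    .set("spark.shuffle.useOldFetchProtocol", "true")
}
```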
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 265a8acfa8d61..6509a04dc4893 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -35,7 +35,8 @@ import org.apache.spark.util.Utils
private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
shuffleId: Int,
- mapId: Int,
+ mapId: Long,
+ mapIndex: Int,
reduceId: Int,
message: String,
cause: Throwable = null)
@@ -44,10 +45,11 @@ private[spark] class FetchFailedException(
def this(
bmAddress: BlockManagerId,
shuffleId: Int,
- mapId: Int,
+ mapTaskId: Long,
+ mapIndex: Int,
reduceId: Int,
cause: Throwable) {
- this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
+ this(bmAddress, shuffleId, mapTaskId, mapIndex, reduceId, cause.getMessage, cause)
}
// SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
@@ -56,8 +58,8 @@ private[spark] class FetchFailedException(
// because the TaskContext is not defined in some test cases.
Option(TaskContext.get()).map(_.setFetchFailed(this))
- def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
- Utils.exceptionString(this))
+ def toTaskFailedReason: TaskFailedReason = FetchFailed(
+ bmAddress, shuffleId, mapId, mapIndex, reduceId, Utils.exceptionString(this))
}
/**
@@ -67,4 +69,4 @@ private[spark] class MetadataFetchFailedException(
shuffleId: Int,
reduceId: Int,
message: String)
- extends FetchFailedException(null, shuffleId, -1, reduceId, message)
+ extends FetchFailedException(null, shuffleId, -1L, -1, reduceId, message)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d3f1c7ec1bbee..332164a7be3e7 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -51,18 +51,18 @@ private[spark] class IndexShuffleBlockResolver(
private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
- def getDataFile(shuffleId: Int, mapId: Int): File = {
+ def getDataFile(shuffleId: Int, mapId: Long): File = {
blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
}
- private def getIndexFile(shuffleId: Int, mapId: Int): File = {
+ private def getIndexFile(shuffleId: Int, mapId: Long): File = {
blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
}
/**
* Remove data file and index file that contain the output data from one map.
*/
- def removeDataByMap(shuffleId: Int, mapId: Int): Unit = {
+ def removeDataByMap(shuffleId: Int, mapId: Long): Unit = {
var file = getDataFile(shuffleId, mapId)
if (file.exists()) {
if (!file.delete()) {
@@ -135,7 +135,7 @@ private[spark] class IndexShuffleBlockResolver(
*/
def writeIndexFileAndCommit(
shuffleId: Int,
- mapId: Int,
+ mapId: Long,
lengths: Array[Long],
dataTmp: File): Unit = {
val indexFile = getIndexFile(shuffleId, mapId)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 18a743fbfa6fc..a717ef242ea7c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -34,13 +34,12 @@ private[spark] trait ShuffleManager {
*/
def registerShuffle[K, V, C](
shuffleId: Int,
- numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle
/** Get a writer for a given partition. Called on executors by map tasks. */
def getWriter[K, V](
handle: ShuffleHandle,
- mapId: Int,
+ mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
index a988c5e126a76..e0affb858c359 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
@@ -21,7 +21,7 @@ import java.io.{Closeable, IOException, OutputStream}
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.api.ShufflePartitionWriter
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.PairsWriter
@@ -39,6 +39,7 @@ private[spark] class ShufflePartitionPairsWriter(
private var isClosed = false
private var partitionStream: OutputStream = _
+ private var timeTrackingStream: OutputStream = _
private var wrappedStream: OutputStream = _
private var objOut: SerializationStream = _
private var numRecordsWritten = 0
@@ -59,7 +60,8 @@ private[spark] class ShufflePartitionPairsWriter(
private def open(): Unit = {
try {
partitionStream = partitionWriter.openStream
- wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
+ timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream)
+ wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream)
objOut = serializerInstance.serializeStream(wrappedStream)
} catch {
case e: Exception =>
@@ -78,6 +80,7 @@ private[spark] class ShufflePartitionPairsWriter(
// Setting these to null will prevent the underlying streams from being closed twice
// just in case any stream's close() implementation is not idempotent.
wrappedStream = null
+ timeTrackingStream = null
partitionStream = null
} {
// Normally closing objOut would close the inner streams as well, but just in case there
@@ -86,9 +89,15 @@ private[spark] class ShufflePartitionPairsWriter(
wrappedStream = closeIfNonNull(wrappedStream)
// Same as above - if wrappedStream closes then assume it closes underlying
// partitionStream and don't close again in the finally
+ timeTrackingStream = null
partitionStream = null
} {
- partitionStream = closeIfNonNull(partitionStream)
+ Utils.tryWithSafeFinally {
+ timeTrackingStream = closeIfNonNull(timeTrackingStream)
+ partitionStream = null
+ } {
+ partitionStream = closeIfNonNull(partitionStream)
+ }
}
}
updateBytesWritten()
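
With the extra `timeTrackingStream` layer, the close path above keeps its existing discipline: close the outermost stream, null out the inner references it is assumed to have closed, and only close a layer directly if the layer above it never got the chance. A self-contained sketch of that pattern with plain JDK streams (names are illustrative):

```scala
import java.io.{ByteArrayOutputStream, OutputStream}
import java.util.zip.GZIPOutputStream

object CloseOnceSketch extends App {
  var inner: OutputStream = new ByteArrayOutputStream()
  val outer: OutputStream = new GZIPOutputStream(inner)

  try {
    outer.close()   // normally closes `inner` as well
    inner = null    // so the finally block does not close it a second time
  } finally {
    if (inner != null) {
      inner.close() // only reached if closing `outer` failed before reaching `inner`
    }
  }
}
```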
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
index 5b0c7e9f2b0b4..f222200a7816c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
@@ -44,7 +44,6 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
def write(
rdd: RDD[_],
dep: ShuffleDependency[_, _, _],
- partitionId: Int,
context: TaskContext,
partition: Partition): MapStatus = {
var writer: ShuffleWriter[Any, Any] = null
@@ -52,7 +51,7 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](
dep.shuffleHandle,
- partitionId,
+ context.taskAttemptId(),
context,
createMetricsReporter(context))
writer.write(
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 2a99c93b32af4..d96bcb3d073df 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -24,6 +24,7 @@ import org.apache.spark.internal.{config, Logging}
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents}
import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.OpenHashSet
/**
* In sort-based shuffle, incoming records are sorted according to their target partition ids, then
@@ -79,9 +80,9 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
}
/**
- * A mapping from shuffle ids to the number of mappers producing output for those shuffles.
+ * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles.
*/
- private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
+ private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]()
private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
@@ -92,7 +93,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
*/
override def registerShuffle[K, V, C](
shuffleId: Int,
- numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
@@ -101,14 +101,14 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleHandle[K, V](
- shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
// Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
new SerializedShuffleHandle[K, V](
- shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
// Otherwise, buffer map outputs in a deserialized form:
- new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ new BaseShuffleHandle(shuffleId, dependency)
}
}
@@ -130,29 +130,29 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](
handle: ShuffleHandle,
- mapId: Int,
+ mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
- numMapsForShuffle.putIfAbsent(
- handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
+ val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent(
+ handle.shuffleId, _ => new OpenHashSet[Long](16))
+ mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) }
val env = SparkEnv.get
handle match {
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
- shuffleBlockResolver,
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
context,
env.conf,
- metrics)
+ metrics,
+ shuffleExecutorComponents)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
bypassMergeSortHandle,
mapId,
- context.taskAttemptId(),
env.conf,
metrics,
shuffleExecutorComponents)
@@ -164,9 +164,9 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Boolean = {
- Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps =>
- (0 until numMaps).foreach { mapId =>
- shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+ Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds =>
+ mapTaskIds.iterator.foreach { mapTaskId =>
+ shuffleBlockResolver.removeDataByMap(shuffleId, mapTaskId)
}
}
true
@@ -231,9 +231,8 @@ private[spark] object SortShuffleManager extends Logging {
*/
private[spark] class SerializedShuffleHandle[K, V](
shuffleId: Int,
- numMaps: Int,
dependency: ShuffleDependency[K, V, V])
- extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+ extends BaseShuffleHandle(shuffleId, dependency) {
}
/**
@@ -242,7 +241,6 @@ private[spark] class SerializedShuffleHandle[K, V](
*/
private[spark] class BypassMergeSortShuffleHandle[K, V](
shuffleId: Int,
- numMaps: Int,
dependency: ShuffleDependency[K, V, V])
- extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+ extends BaseShuffleHandle(shuffleId, dependency) {
}
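
The bookkeeping change above swaps a per-shuffle mapper count for the set of map task attempt ids that actually produced output, so `unregisterShuffle` can remove exactly the files that exist. A small sketch of the same idea with standard collections (the real code uses `OpenHashSet`; all names here are illustrative):

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable

object ShuffleTaskIdSketch {
  private val taskIdsForShuffle = new ConcurrentHashMap[Int, mutable.HashSet[Long]]()

  // Called by each map task when it obtains a writer.
  def recordWriter(shuffleId: Int, taskAttemptId: Long): Unit = {
    val ids = taskIdsForShuffle.computeIfAbsent(shuffleId, _ => new mutable.HashSet[Long]())
    ids.synchronized { ids += taskAttemptId }
  }

  // Called once per shuffle on cleanup; removes only files that were actually written.
  def unregister(shuffleId: Int)(removeDataByMap: (Int, Long) => Unit): Unit = {
    Option(taskIdsForShuffle.remove(shuffleId)).foreach { ids =>
      ids.foreach(id => removeDataByMap(shuffleId, id))
    }
  }
}
```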
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index a781b16252432..a391bdf2db44e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -27,7 +27,7 @@ import org.apache.spark.util.collection.ExternalSorter
private[spark] class SortShuffleWriter[K, V, C](
shuffleBlockResolver: IndexShuffleBlockResolver,
handle: BaseShuffleHandle[K, V, C],
- mapId: Int,
+ mapId: Long,
context: TaskContext,
shuffleExecutorComponents: ShuffleExecutorComponents)
extends ShuffleWriter[K, V] with Logging {
@@ -65,10 +65,10 @@ private[spark] class SortShuffleWriter[K, V, C](
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
- dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions)
+ dep.shuffleId, mapId, dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
val partitionLengths = mapOutputWriter.commitAllPartitions()
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
+ mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}
/** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala
new file mode 100644
index 0000000000000..6e52e213bda8e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/PrometheusResource.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.status.api.v1
+
+import javax.ws.rs._
+import javax.ws.rs.core.MediaType
+
+import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
+import org.glassfish.jersey.server.ServerProperties
+import org.glassfish.jersey.servlet.ServletContainer
+
+import org.apache.spark.ui.SparkUI
+
+/**
+ * This aims to expose executor metrics, similarly to the REST API documented at
+ *
+ * https://spark.apache.org/docs/3.0.0/monitoring.html#executor-metrics
+ *
+ * Note that this is based on ExecutorSummary, which is different from ExecutorSource.
+ */
+@Path("/executors")
+private[v1] class PrometheusResource extends ApiRequestContext {
+ @GET
+ @Path("prometheus")
+ @Produces(Array(MediaType.TEXT_PLAIN))
+ def executors(): String = {
+ val sb = new StringBuilder
+ val store = uiRoot.asInstanceOf[SparkUI].store
+ val appId = store.applicationInfo.id.replaceAll("[^a-zA-Z0-9]", "_")
+ store.executorList(true).foreach { executor =>
+ val prefix = s"metrics_${appId}_${executor.id}_executor_"
+ sb.append(s"${prefix}rddBlocks_Count ${executor.rddBlocks}\n")
+ sb.append(s"${prefix}memoryUsed_Count ${executor.memoryUsed}\n")
+ sb.append(s"${prefix}diskUsed_Count ${executor.diskUsed}\n")
+ sb.append(s"${prefix}totalCores_Count ${executor.totalCores}\n")
+ sb.append(s"${prefix}maxTasks_Count ${executor.maxTasks}\n")
+ sb.append(s"${prefix}activeTasks_Count ${executor.activeTasks}\n")
+ sb.append(s"${prefix}failedTasks_Count ${executor.failedTasks}\n")
+ sb.append(s"${prefix}completedTasks_Count ${executor.completedTasks}\n")
+ sb.append(s"${prefix}totalTasks_Count ${executor.totalTasks}\n")
+ sb.append(s"${prefix}totalDuration_Value ${executor.totalDuration}\n")
+ sb.append(s"${prefix}totalGCTime_Value ${executor.totalGCTime}\n")
+ sb.append(s"${prefix}totalInputBytes_Count ${executor.totalInputBytes}\n")
+ sb.append(s"${prefix}totalShuffleRead_Count ${executor.totalShuffleRead}\n")
+ sb.append(s"${prefix}totalShuffleWrite_Count ${executor.totalShuffleWrite}\n")
+ sb.append(s"${prefix}maxMemory_Count ${executor.maxMemory}\n")
+ executor.executorLogs.foreach { case (k, v) => }
+ executor.memoryMetrics.foreach { m =>
+ sb.append(s"${prefix}usedOnHeapStorageMemory_Count ${m.usedOnHeapStorageMemory}\n")
+ sb.append(s"${prefix}usedOffHeapStorageMemory_Count ${m.usedOffHeapStorageMemory}\n")
+ sb.append(s"${prefix}totalOnHeapStorageMemory_Count ${m.totalOnHeapStorageMemory}\n")
+ sb.append(s"${prefix}totalOffHeapStorageMemory_Count ${m.totalOffHeapStorageMemory}\n")
+ }
+ executor.peakMemoryMetrics.foreach { m =>
+ val names = Array(
+ "JVMHeapMemory",
+ "JVMOffHeapMemory",
+ "OnHeapExecutionMemory",
+ "OffHeapExecutionMemory",
+ "OnHeapStorageMemory",
+ "OffHeapStorageMemory",
+ "OnHeapUnifiedMemory",
+ "OffHeapUnifiedMemory",
+ "DirectPoolMemory",
+ "MappedPoolMemory",
+ "ProcessTreeJVMVMemory",
+ "ProcessTreeJVMRSSMemory",
+ "ProcessTreePythonVMemory",
+ "ProcessTreePythonRSSMemory",
+ "ProcessTreeOtherVMemory",
+ "ProcessTreeOtherRSSMemory",
+ "MinorGCCount",
+ "MinorGCTime",
+ "MajorGCCount",
+ "MajorGCTime"
+ )
+ names.foreach { name =>
+ sb.append(s"$prefix${name}_Count ${m.getMetricValue(name)}\n")
+ }
+ }
+ }
+ sb.toString
+ }
+}
+
+private[spark] object PrometheusResource {
+ def getServletHandler(uiRoot: UIRoot): ServletContextHandler = {
+ val jerseyContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS)
+ jerseyContext.setContextPath("/metrics")
+ val holder: ServletHolder = new ServletHolder(classOf[ServletContainer])
+ holder.setInitParameter(ServerProperties.PROVIDER_PACKAGES, "org.apache.spark.status.api.v1")
+ UIRootFromServletContext.setUiRoot(jerseyContext, uiRoot)
+ jerseyContext.addServlet(holder, "/*")
+ jerseyContext
+ }
+}
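
Given the `/metrics` servlet context and the JAX-RS paths above, the new endpoint is served by the driver UI at `/metrics/executors/prometheus` and emits one plain-text line per metric. A hedged sketch of the output shape (the application id, executor id, and values below are made up):

```
metrics_app_20190901120000_0001_driver_executor_rddBlocks_Count 0
metrics_app_20190901120000_0001_driver_executor_memoryUsed_Count 524288
metrics_app_20190901120000_0001_driver_executor_totalCores_Count 4
metrics_app_20190901120000_0001_driver_executor_usedOnHeapStorageMemory_Count 524288
```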
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7ac2c71c18eb3..9c5b7f64e7abe 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -52,17 +52,17 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
// Format of the shuffle block ids (including data and index) should be kept in sync with
// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData().
@DeveloperApi
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
@DeveloperApi
-case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+case class ShuffleDataBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
}
@DeveloperApi
-case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
}
@@ -117,11 +117,11 @@ object BlockId {
case RDD(rddId, splitIndex) =>
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
- ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ ShuffleBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
case SHUFFLE_DATA(shuffleId, mapId, reduceId) =>
- ShuffleDataBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
- ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
case BROADCAST(broadcastId, field) =>
BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
case TASKRESULT(taskId) =>
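
Since the map field of the shuffle block ids is now the Long map task attempt id rather than the Int map partition index, the string form parses with `toLong` in the middle position. A tiny self-contained sketch of the round trip (the regex mirrors the one used by `BlockId.apply`, but is simplified here):

```scala
object ShuffleBlockIdSketch extends App {
  private val Shuffle = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r

  def parse(name: String): (Int, Long, Int) = name match {
    case Shuffle(shuffleId, mapId, reduceId) =>
      (shuffleId.toInt, mapId.toLong, reduceId.toInt)
    case other =>
      throw new IllegalArgumentException(s"Not a shuffle block id: $other")
  }

  // A map task attempt id can exceed Int range on long-running applications.
  assert(parse("shuffle_3_4521_7") == ((3, 4521L, 7)))
}
```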
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 4b71dc1fff345..a7dfc20d15ebc 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -446,7 +446,7 @@ private[spark] class BlockManager(
}
}
- private def registerWithExternalShuffleServer() {
+ private def registerWithExternalShuffleServer(): Unit = {
logInfo("Registering executor with local external shuffle service.")
val shuffleConfig = new ExecutorShuffleInfo(
diskBlockManager.localDirsString,
@@ -853,7 +853,6 @@ private[spark] class BlockManager(
* @param bufferTransformer this transformer is expected to open the file if the block is backed by a
* file; this guarantees that the whole content can be loaded
* @tparam T result type
- * @return
*/
private[spark] def getRemoteBlock[T](
blockId: BlockId,
@@ -1725,15 +1724,23 @@ private[spark] class BlockManager(
* lock on the block.
*/
private def removeBlockInternal(blockId: BlockId, tellMaster: Boolean): Unit = {
+ val blockStatus = if (tellMaster) {
+ val blockInfo = blockInfoManager.assertBlockIsLockedForWriting(blockId)
+ Some(getCurrentBlockStatus(blockId, blockInfo))
+ } else None
+
// Removals are idempotent in disk store and memory store. At worst, we get a warning.
val removedFromMemory = memoryStore.remove(blockId)
val removedFromDisk = diskStore.remove(blockId)
if (!removedFromMemory && !removedFromDisk) {
logWarning(s"Block $blockId could not be removed as it was not found on disk or in memory")
}
+
blockInfoManager.removeBlock(blockId)
if (tellMaster) {
- reportBlockStatus(blockId, BlockStatus.empty)
+ // Only update the storage level from the captured block status before deleting, so that
+ // the memory size and disk size are kept for calculating the delta.
+ reportBlockStatus(blockId, blockStatus.get.copy(storageLevel = StorageLevel.NONE))
}
}
@@ -1831,7 +1838,7 @@ private[spark] object BlockManager {
private val POLL_TIMEOUT = 1000
@volatile private var stopped = false
- private val cleaningThread = new Thread() { override def run() { keepCleaning() } }
+ private val cleaningThread = new Thread() { override def run(): Unit = { keepCleaning() } }
cleaningThread.setDaemon(true)
cleaningThread.setName("RemoteBlock-temp-file-clean-thread")
cleaningThread.start()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index d188bdd912e5e..49e32d04d450a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -27,7 +27,7 @@ import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
- * This class represent an unique identifier for a BlockManager.
+ * This class represents a unique identifier for a BlockManager.
*
* The first 2 constructors of this class are made private to ensure that BlockManagerId objects
* can be created only using the apply method in the companion object. This allows de-duplication
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 9d13fedfb0c58..525304fe3c9d3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -37,7 +37,7 @@ class BlockManagerMaster(
val timeout = RpcUtils.askRpcTimeout(conf)
/** Remove a dead executor from the driver endpoint. This is only called on the driver side. */
- def removeExecutor(execId: String) {
+ def removeExecutor(execId: String): Unit = {
tell(RemoveExecutor(execId))
logInfo("Removed " + execId + " successfully in removeExecutor")
}
@@ -45,7 +45,7 @@ class BlockManagerMaster(
/** Request removal of a dead executor from the driver endpoint.
* This is only called on the driver side. Non-blocking
*/
- def removeExecutorAsync(execId: String) {
+ def removeExecutorAsync(execId: String): Unit = {
driverEndpoint.ask[Boolean](RemoveExecutor(execId))
logInfo("Removal of executor " + execId + " requested")
}
@@ -120,12 +120,12 @@ class BlockManagerMaster(
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
*/
- def removeBlock(blockId: BlockId) {
+ def removeBlock(blockId: BlockId): Unit = {
driverEndpoint.askSync[Boolean](RemoveBlock(blockId))
}
/** Remove all blocks belonging to the given RDD. */
- def removeRdd(rddId: Int, blocking: Boolean) {
+ def removeRdd(rddId: Int, blocking: Boolean): Unit = {
val future = driverEndpoint.askSync[Future[Seq[Int]]](RemoveRdd(rddId))
future.failed.foreach(e =>
logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e)
@@ -136,7 +136,7 @@ class BlockManagerMaster(
}
/** Remove all blocks belonging to the given shuffle. */
- def removeShuffle(shuffleId: Int, blocking: Boolean) {
+ def removeShuffle(shuffleId: Int, blocking: Boolean): Unit = {
val future = driverEndpoint.askSync[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
future.failed.foreach(e =>
logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e)
@@ -147,7 +147,7 @@ class BlockManagerMaster(
}
/** Remove all blocks belonging to the given broadcast. */
- def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
+ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean): Unit = {
val future = driverEndpoint.askSync[Future[Seq[Int]]](
RemoveBroadcast(broadcastId, removeFromMaster))
future.failed.foreach(e =>
@@ -226,7 +226,7 @@ class BlockManagerMaster(
}
/** Stop the driver endpoint, called only on the Spark driver node */
- def stop() {
+ def stop(): Unit = {
if (driverEndpoint != null && isDriver) {
tell(StopBlockManagerMaster)
driverEndpoint = null
@@ -235,7 +235,7 @@ class BlockManagerMaster(
}
/** Send a one-way message to the master endpoint, to which we expect it to reply with true. */
- private def tell(message: Any) {
+ private def tell(message: Any): Unit = {
if (!driverEndpoint.askSync[Boolean](message)) {
throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.")
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 5e021b334fd2b..faf6f713c838f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -243,7 +243,7 @@ class BlockManagerMasterEndpoint(
Future.sequence(futures)
}
- private def removeBlockManager(blockManagerId: BlockManagerId) {
+ private def removeBlockManager(blockManagerId: BlockManagerId): Unit = {
val info = blockManagerInfo(blockManagerId)
// Remove the block manager from blockManagerIdByExecutor.
@@ -285,7 +285,7 @@ class BlockManagerMasterEndpoint(
}
- private def removeExecutor(execId: String) {
+ private def removeExecutor(execId: String): Unit = {
logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
}
@@ -305,7 +305,7 @@ class BlockManagerMasterEndpoint(
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- private def removeBlockFromWorkers(blockId: BlockId) {
+ private def removeBlockFromWorkers(blockId: BlockId): Unit = {
val locations = blockLocations.get(blockId)
if (locations != null) {
locations.foreach { blockManagerId: BlockManagerId =>
@@ -593,7 +593,7 @@ private[spark] class BlockManagerInfo(
def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId))
- def updateLastSeenMs() {
+ def updateLastSeenMs(): Unit = {
_lastSeenMs = System.currentTimeMillis()
}
@@ -601,7 +601,7 @@ private[spark] class BlockManagerInfo(
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long) {
+ diskSize: Long): Unit = {
updateLastSeenMs()
@@ -681,7 +681,7 @@ private[spark] class BlockManagerInfo(
}
}
- def removeBlock(blockId: BlockId) {
+ def removeBlock(blockId: BlockId): Unit = {
if (_blocks.containsKey(blockId)) {
_remainingMem += _blocks.get(blockId).memSize
_blocks.remove(blockId)
@@ -699,7 +699,7 @@ private[spark] class BlockManagerInfo(
override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
- def clear() {
+ def clear(): Unit = {
_blocks.clear()
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
index 67544b20408a6..f90595ab924b4 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
@@ -80,7 +80,7 @@ class BlockManagerSlaveEndpoint(
}
- private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) {
+ private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
val future = Future {
logDebug(actionMessage)
body
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index c3990bf71e604..f2113947f6bf5 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -161,7 +161,7 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea
}
/** Cleanup local dirs and stop shuffle sender. */
- private[spark] def stop() {
+ private[spark] def stop(): Unit = {
// Remove the shutdown hook. It causes memory leaks if we leave it around.
try {
ShutdownHookManager.removeShutdownHook(shutdownHook)
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 758621c52495b..e55c09274cd9a 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -18,7 +18,7 @@
package org.apache.spark.storage
import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
-import java.nio.channels.FileChannel
+import java.nio.channels.{ClosedByInterruptException, FileChannel}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
@@ -150,7 +150,7 @@ private[spark] class DiskBlockObjectWriter(
/**
* Commits any remaining partial writes and closes resources.
*/
- override def close() {
+ override def close(): Unit = {
if (initialized) {
Utils.tryWithSafeFinally {
commitAndGet()
@@ -219,6 +219,12 @@ private[spark] class DiskBlockObjectWriter(
truncateStream = new FileOutputStream(file, true)
truncateStream.getChannel.truncate(committedPosition)
} catch {
+ // ClosedByInterruptException is an expected exception when a task is killed;
+ // don't log the exception stack trace to avoid confusing users.
+ // See: SPARK-28340
+ case ce: ClosedByInterruptException =>
+ logError("Exception occurred while reverting partial writes to file "
+ + file + ", " + ce.getMessage)
case e: Exception =>
logError("Uncaught exception while reverting partial writes to file " + file, e)
} finally {
@@ -234,7 +240,7 @@ private[spark] class DiskBlockObjectWriter(
/**
* Writes a key-value pair.
*/
- override def write(key: Any, value: Any) {
+ override def write(key: Any, value: Any): Unit = {
if (!streamOpen) {
open()
}
@@ -270,14 +276,14 @@ private[spark] class DiskBlockObjectWriter(
* Report the number of bytes written in this writer's shuffle write metrics.
* Note that this is only valid before the underlying streams are closed.
*/
- private def updateBytesWritten() {
+ private def updateBytesWritten(): Unit = {
val pos = channel.position()
writeMetrics.incBytesWritten(pos - reportedPosition)
reportedPosition = pos
}
// For testing
- private[spark] override def flush() {
+ private[spark] override def flush(): Unit = {
objOut.flush()
bs.flush()
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index a5b7ee5762c49..dce5ebaebbae5 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.storage
import java.io.{InputStream, IOException}
+import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import javax.annotation.concurrent.GuardedBy
@@ -48,9 +49,10 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
* @param shuffleClient [[BlockStoreClient]] for fetching remote blocks
* @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
- * For each block we also require the size (in bytes as a long field) in
- * order to throttle the memory usage. Note that zero-sized blocks are
- * already excluded, which happened in
+ * For each block we also require two pieces of info: 1. the size (in bytes as a
+ * long field) in order to throttle the memory usage; 2. the mapIndex for this
+ * block, which indicates its index in the map stage.
+ * Note that zero-sized blocks are already excluded, which is done in
* [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
* @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
@@ -66,7 +68,7 @@ final class ShuffleBlockFetcherIterator(
context: TaskContext,
shuffleClient: BlockStoreClient,
blockManager: BlockManager,
- blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
@@ -96,7 +98,7 @@ final class ShuffleBlockFetcherIterator(
private[this] val startTimeNs = System.nanoTime()
/** Local blocks to fetch, excluding zero-sized blocks. */
- private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]()
+ private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
/** Remote blocks to fetch, excluding zero-sized blocks. */
private[this] val remoteBlocks = new HashSet[BlockId]()
@@ -188,7 +190,7 @@ final class ShuffleBlockFetcherIterator(
/**
* Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
*/
- private[storage] def cleanup() {
+ private[storage] def cleanup(): Unit = {
synchronized {
isZombie = true
}
@@ -198,7 +200,7 @@ final class ShuffleBlockFetcherIterator(
while (iter.hasNext) {
val result = iter.next()
result match {
- case SuccessFetchResult(_, address, _, buf, _) =>
+ case SuccessFetchResult(_, _, address, _, buf, _) =>
if (address != blockManager.blockManagerId) {
shuffleMetrics.incRemoteBytesRead(buf.size)
if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
@@ -217,16 +219,18 @@ final class ShuffleBlockFetcherIterator(
}
}
- private[this] def sendRequest(req: FetchRequest) {
+ private[this] def sendRequest(req: FetchRequest): Unit = {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
bytesInFlight += req.size
reqsInFlight += 1
- // so we can look up the size of each blockID
- val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
- val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
- val blockIds = req.blocks.map(_._1.toString)
+ // so we can look up the block info of each blockID
+ val infoMap = req.blocks.map {
+ case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex))
+ }.toMap
+ val remainingBlocks = new HashSet[String]() ++= infoMap.keys
+ val blockIds = req.blocks.map(_.blockId.toString)
val address = req.address
val blockFetchingListener = new BlockFetchingListener {
@@ -239,8 +243,8 @@ final class ShuffleBlockFetcherIterator(
// This needs to be released after use.
buf.retain()
remainingBlocks -= blockId
- results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
- remainingBlocks.isEmpty))
+ results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
+ address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
logDebug("remainingBlocks: " + remainingBlocks)
}
}
@@ -249,7 +253,7 @@ final class ShuffleBlockFetcherIterator(
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- results.put(new FailureFetchResult(BlockId(blockId), address, e))
+ results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
}
}
@@ -282,28 +286,28 @@ final class ShuffleBlockFetcherIterator(
for ((address, blockInfos) <- blocksByAddress) {
if (address.executorId == blockManager.blockManagerId.executorId) {
blockInfos.find(_._2 <= 0) match {
- case Some((blockId, size)) if size < 0 =>
+ case Some((blockId, size, _)) if size < 0 =>
throw new BlockException(blockId, "Negative block size " + size)
- case Some((blockId, size)) if size == 0 =>
+ case Some((blockId, size, _)) if size == 0 =>
throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
case None => // do nothing.
}
- localBlocks ++= blockInfos.map(_._1)
+ localBlocks ++= blockInfos.map(info => (info._1, info._3))
localBlockBytes += blockInfos.map(_._2).sum
numBlocksToFetch += localBlocks.size
} else {
val iterator = blockInfos.iterator
var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(BlockId, Long)]
+ var curBlocks = new ArrayBuffer[FetchBlockInfo]
while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
+ val (blockId, size, mapIndex) = iterator.next()
remoteBlockBytes += size
if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
} else if (size == 0) {
throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
} else {
- curBlocks += ((blockId, size))
+ curBlocks += FetchBlockInfo(blockId, size, mapIndex)
remoteBlocks += blockId
numBlocksToFetch += 1
curRequestSize += size
@@ -314,7 +318,7 @@ final class ShuffleBlockFetcherIterator(
remoteRequests += new FetchRequest(address, curBlocks)
logDebug(s"Creating fetch request of $curRequestSize at $address "
+ s"with ${curBlocks.size} blocks")
- curBlocks = new ArrayBuffer[(BlockId, Long)]
+ curBlocks = new ArrayBuffer[FetchBlockInfo]
curRequestSize = 0
}
}
@@ -336,23 +340,30 @@ final class ShuffleBlockFetcherIterator(
* `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
* track in-memory are the ManagedBuffer references themselves.
*/
- private[this] def fetchLocalBlocks() {
+ private[this] def fetchLocalBlocks(): Unit = {
logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
val iter = localBlocks.iterator
while (iter.hasNext) {
- val blockId = iter.next()
+ val (blockId, mapIndex) = iter.next()
try {
val buf = blockManager.getBlockData(blockId)
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
- results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId,
+ results.put(new SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId,
buf.size(), buf, false))
} catch {
+ // If we see an exception, stop immediately.
case e: Exception =>
- // If we see an exception, stop immediately.
- logError(s"Error occurred while fetching local blocks", e)
- results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
+ e match {
+ // ClosedByInterruptException is an expected exception when a task is killed;
+ // don't log the exception stack trace to avoid confusing users.
+ // See: SPARK-28340
+ case ce: ClosedByInterruptException =>
+ logError("Error occurred while fetching local blocks, " + ce.getMessage)
+ case ex: Exception => logError("Error occurred while fetching local blocks", ex)
+ }
+ results.put(new FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e))
return
}
}
@@ -412,7 +423,7 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incFetchWaitTime(fetchWaitTime)
result match {
- case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+ case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
if (address != blockManager.blockManagerId) {
numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
shuffleMetrics.incRemoteBytesRead(buf.size)
@@ -421,7 +432,7 @@ final class ShuffleBlockFetcherIterator(
}
shuffleMetrics.incRemoteBlocksFetched(1)
}
- if (!localBlocks.contains(blockId)) {
+ if (!localBlocks.contains((blockId, mapIndex))) {
bytesInFlight -= size
}
if (isNetworkReqDone) {
@@ -445,7 +456,7 @@ final class ShuffleBlockFetcherIterator(
// since the last call.
val msg = s"Received a zero-size buffer for block $blockId from $address " +
s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
- throwFetchFailedException(blockId, address, new IOException(msg))
+ throwFetchFailedException(blockId, mapIndex, address, new IOException(msg))
}
val in = try {
@@ -454,9 +465,14 @@ final class ShuffleBlockFetcherIterator(
// The exception could only be thrown by a local shuffle block
case e: IOException =>
assert(buf.isInstanceOf[FileSegmentManagedBuffer])
- logError("Failed to create input stream from local block", e)
+ e match {
+ case ce: ClosedByInterruptException =>
+ logError("Failed to create input stream from local block, " +
+ ce.getMessage)
+ case e: IOException => logError("Failed to create input stream from local block", e)
+ }
buf.release()
- throwFetchFailedException(blockId, address, e)
+ throwFetchFailedException(blockId, mapIndex, address, e)
}
try {
input = streamWrapper(blockId, in)
@@ -474,11 +490,12 @@ final class ShuffleBlockFetcherIterator(
buf.release()
if (buf.isInstanceOf[FileSegmentManagedBuffer]
|| corruptedBlocks.contains(blockId)) {
- throwFetchFailedException(blockId, address, e)
+ throwFetchFailedException(blockId, mapIndex, address, e)
} else {
logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
corruptedBlocks += blockId
- fetchRequests += FetchRequest(address, Array((blockId, size)))
+ fetchRequests += FetchRequest(
+ address, Array(FetchBlockInfo(blockId, size, mapIndex)))
result = null
}
} finally {
@@ -490,8 +507,8 @@ final class ShuffleBlockFetcherIterator(
}
}
- case FailureFetchResult(blockId, address, e) =>
- throwFetchFailedException(blockId, address, e)
+ case FailureFetchResult(blockId, mapIndex, address, e) =>
+ throwFetchFailedException(blockId, mapIndex, address, e)
}
// Send fetch requests up to maxBytesInFlight
@@ -504,6 +521,7 @@ final class ShuffleBlockFetcherIterator(
input,
this,
currentResult.blockId,
+ currentResult.mapIndex,
currentResult.address,
detectCorrupt && streamCompressedOrEncrypted))
}
@@ -570,11 +588,12 @@ final class ShuffleBlockFetcherIterator(
private[storage] def throwFetchFailedException(
blockId: BlockId,
+ mapIndex: Int,
address: BlockManagerId,
e: Throwable) = {
blockId match {
case ShuffleBlockId(shufId, mapId, reduceId) =>
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
+ throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, e)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block", e)
@@ -591,6 +610,7 @@ private class BufferReleasingInputStream(
private[storage] val delegate: InputStream,
private val iterator: ShuffleBlockFetcherIterator,
private val blockId: BlockId,
+ private val mapIndex: Int,
private val address: BlockManagerId,
private val detectCorruption: Boolean)
extends InputStream {
@@ -602,7 +622,7 @@ private class BufferReleasingInputStream(
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
- iterator.throwFetchFailedException(blockId, address, e)
+ iterator.throwFetchFailedException(blockId, mapIndex, address, e)
}
}
@@ -624,7 +644,7 @@ private class BufferReleasingInputStream(
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
- iterator.throwFetchFailedException(blockId, address, e)
+ iterator.throwFetchFailedException(blockId, mapIndex, address, e)
}
}
@@ -636,7 +656,7 @@ private class BufferReleasingInputStream(
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
- iterator.throwFetchFailedException(blockId, address, e)
+ iterator.throwFetchFailedException(blockId, mapIndex, address, e)
}
}
@@ -646,7 +666,7 @@ private class BufferReleasingInputStream(
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
- iterator.throwFetchFailedException(blockId, address, e)
+ iterator.throwFetchFailedException(blockId, mapIndex, address, e)
}
}
@@ -677,14 +697,25 @@ private class ShuffleFetchCompletionListener(var data: ShuffleBlockFetcherIterat
private[storage]
object ShuffleBlockFetcherIterator {
+ /**
+ * The block information to fetch used in FetchRequest.
+ * @param blockId block id
+ * @param size estimated size of the block. Note that this is NOT the exact bytes.
+ * Size of remote block is used to calculate bytesInFlight.
+ * @param mapIndex the mapIndex for this block, which indicates the index in the map stage.
+ */
+ private[storage] case class FetchBlockInfo(
+ blockId: BlockId,
+ size: Long,
+ mapIndex: Int)
+
/**
* A request to fetch blocks from a remote BlockManager.
* @param address remote BlockManager to fetch from.
- * @param blocks Sequence of tuple, where the first element is the block id,
- * and the second element is the estimated size, used to calculate bytesInFlight.
+ * @param blocks Sequence of the information for blocks to fetch from the same address.
*/
- case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
- val size = blocks.map(_._2).sum
+ case class FetchRequest(address: BlockManagerId, blocks: Seq[FetchBlockInfo]) {
+ val size = blocks.map(_.size).sum
}
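
As a standalone illustration (not part of the patch) of how these FetchBlockInfo entries are grouped into size-capped FetchRequests, a minimal Scala sketch with simplified stand-in case classes and an assumed targetRequestSize cap:

    import scala.collection.mutable.ArrayBuffer

    object FetchRequestGroupingSketch {
      // Simplified stand-ins for the private[storage] FetchBlockInfo / FetchRequest above.
      case class BlockInfo(blockId: String, size: Long, mapIndex: Int)
      case class Request(blocks: Seq[BlockInfo]) {
        val size: Long = blocks.map(_.size).sum
      }

      // Accumulate blocks until the running size reaches the target, then start a new
      // request, mirroring the curRequestSize bookkeeping in the loop above.
      def splitIntoRequests(blocks: Seq[BlockInfo], targetRequestSize: Long): Seq[Request] = {
        val requests = new ArrayBuffer[Request]
        var curBlocks = new ArrayBuffer[BlockInfo]
        var curRequestSize = 0L
        for (info <- blocks) {
          curBlocks += info
          curRequestSize += info.size
          if (curRequestSize >= targetRequestSize) {
            requests += Request(curBlocks.toSeq)
            curBlocks = new ArrayBuffer[BlockInfo]
            curRequestSize = 0L
          }
        }
        if (curBlocks.nonEmpty) requests += Request(curBlocks.toSeq)
        requests.toSeq
      }
    }
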
/**
@@ -698,6 +729,7 @@ object ShuffleBlockFetcherIterator {
/**
* Result of a successful fetch of a remote block.
* @param blockId block id
+ * @param mapIndex the mapIndex for this block, which indicates the index in the map stage.
* @param address BlockManager that the block was fetched from.
* @param size estimated size of the block. Note that this is NOT the exact bytes.
* Size of remote block is used to calculate bytesInFlight.
@@ -706,6 +738,7 @@ object ShuffleBlockFetcherIterator {
*/
private[storage] case class SuccessFetchResult(
blockId: BlockId,
+ mapIndex: Int,
address: BlockManagerId,
size: Long,
buf: ManagedBuffer,
@@ -717,11 +750,13 @@ object ShuffleBlockFetcherIterator {
/**
* Result of an unsuccessful fetch of a remote block.
* @param blockId block id
+ * @param mapIndex the mapIndex for this block, which indicates the index in the map stage
* @param address BlockManager that the block was attempted to be fetched from
* @param e the failure exception
*/
private[storage] case class FailureFetchResult(
blockId: BlockId,
+ mapIndex: Int,
address: BlockManagerId,
e: Throwable)
extends FetchResult
diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
index f36b31c65a63d..5c59859d14f76 100644
--- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
@@ -48,7 +48,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
// Schedule a refresh thread to run periodically
private val timer = new Timer("refresh progress", true)
timer.schedule(new TimerTask{
- override def run() {
+ override def run(): Unit = {
refresh()
}
}, firstDelayMSec, updatePeriodMSec)
@@ -73,7 +73,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
* after your last output, and keeps overwriting itself to stay on one line. Logging will follow
* the progress bar, and the progress bar will then be shown on the next line without overwriting the logs.
*/
- private def show(now: Long, stages: Seq[StageData]) {
+ private def show(now: Long, stages: Seq[StageData]): Unit = {
val width = TerminalWidth / stages.size
val bar = stages.map { s =>
val total = s.numTasks
@@ -103,7 +103,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
/**
* Clear the progress bar if showed.
*/
- private def clear() {
+ private def clear(): Unit = {
if (!lastProgressBar.isEmpty) {
System.err.printf(CR + " " * TerminalWidth + CR)
lastProgressBar = ""
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index ff7baf4d9419b..cd4104731d400 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -73,7 +73,7 @@ private[spark] object JettyUtils extends Logging {
servletParams: ServletParams[T],
conf: SparkConf): HttpServlet = {
new HttpServlet {
- override def doGet(request: HttpServletRequest, response: HttpServletResponse) {
+ override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = {
try {
response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
response.setStatus(HttpServletResponse.SC_OK)
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 6fb8e458a789c..8ae9828c3fee1 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -66,6 +66,9 @@ private[spark] class SparkUI private (
addStaticHandler(SparkUI.STATIC_RESOURCE_DIR)
attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath))
attachHandler(ApiRootResource.getServletHandler(this))
+ if (sc.map(_.conf.get(UI_PROMETHEUS_ENABLED)).getOrElse(false)) {
+ attachHandler(PrometheusResource.getServletHandler(this))
+ }
// These should be POST only, but, the YARN AM proxy won't proxy POSTs
attachHandler(createRedirectHandler(
@@ -94,7 +97,7 @@ private[spark] class SparkUI private (
}
/** Stop the server behind this web interface. Only valid after bind(). */
- override def stop() {
+ override def stop(): Unit = {
super.stop()
logInfo(s"Stopped Spark web UI at $webUrl")
}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 70e24bd0e7ecd..6dbe63b564e69 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -309,9 +309,13 @@ private[spark] object UIUtils extends Logging {
data: Iterable[T],
fixedWidth: Boolean = false,
id: Option[String] = None,
+ // When headerClasses is not empty, it should have the same length as the headers parameter
headerClasses: Seq[String] = Seq.empty,
stripeRowsWithCss: Boolean = true,
- sortable: Boolean = true): Seq[Node] = {
+ sortable: Boolean = true,
+ // The tooltip information could be None, which indicates the header does not have a tooltip.
+ // When tooltipHeaders is not empty, it should have the same length as the headers parameter
+ tooltipHeaders: Seq[Option[String]] = Seq.empty): Seq[Node] = {
val listingTableClass = {
val _tableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED
@@ -332,6 +336,14 @@ private[spark] object UIUtils extends Logging {
}
}
+ def getTooltip(index: Int): Option[String] = {
+ if (index < tooltipHeaders.size) {
+ tooltipHeaders(index)
+ } else {
+ None
+ }
+ }
+
val newlinesInHeader = headers.exists(_.contains("\n"))
def getHeaderContent(header: String): Seq[Node] = {
if (newlinesInHeader) {
@@ -345,7 +357,15 @@ private[spark] object UIUtils extends Logging {
val headerRow: Seq[Node] = {
headers.view.zipWithIndex.map { x =>
- <th width={colWidthAttr} class={getClass(x._2)}>{getHeaderContent(x._1)}</th>
+ getTooltip(x._2) match {
+ case Some(tooltip) =>
+ <th width={colWidthAttr} class={getClass(x._2)}>
+ <span data-toggle="tooltip" title={tooltip}>
+ {getHeaderContent(x._1)}
+ </span>
+ </th>
+ case None => <th width={colWidthAttr} class={getClass(x._2)}>{getHeaderContent(x._1)}</th>
+ }
}
}
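
A standalone sketch (using scala-xml literals, as the UI code does) of the optional-tooltip header pattern introduced above; headerRow here is a simplified stand-in, not the listingTable implementation:

    import scala.xml.Node

    object TooltipHeaderSketch {
      def headerRow(headers: Seq[String], tooltips: Seq[Option[String]]): Seq[Node] = {
        headers.zipWithIndex.map { case (header, i) =>
          // Headers without a corresponding tooltip entry simply get a plain <th>.
          val tooltip = if (i < tooltips.size) tooltips(i) else None
          tooltip match {
            case Some(tip) =>
              <th><span data-toggle="tooltip" title={tip}>{header}</span></th>
            case None =>
              <th>{header}</th>
          }
        }
      }
    }
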
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index 8845dcf48a844..ca111a8d00a64 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -37,7 +37,7 @@ private[spark] object UIWorkloadGenerator {
val NUM_PARTITIONS = 100
val INTER_JOB_WAIT_MS = 5000
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 3) {
// scalastyle:off println
println(
@@ -98,7 +98,7 @@ private[spark] object UIWorkloadGenerator {
(1 to nJobSet).foreach { _ =>
for ((desc, job) <- jobs) {
new Thread {
- override def run() {
+ override def run(): Unit = {
// scalastyle:off println
try {
setProperties(desc)
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 1fe822a0e3b57..9faa3dcf2cdf2 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -184,7 +184,7 @@ private[spark] abstract class WebUITab(parent: WebUI, val prefix: String) {
val name = prefix.capitalize
/** Attach a page to this tab. This prepends the page's prefix with the tab's own prefix. */
- def attachPage(page: WebUIPage) {
+ def attachPage(page: WebUIPage): Unit = {
page.prefix = (prefix + "/" + page.prefix).stripSuffix("/")
pages += page
}
@@ -236,4 +236,8 @@ private[spark] class DelegatingServletContextHandler(handler: ServletContextHand
def filterCount(): Int = {
handler.getServletHandler.getFilters.length
}
+
+ def getContextPath(): String = {
+ handler.getContextPath
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
index eff0aa4453f08..0827395fea0bb 100644
--- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
+++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
@@ -59,7 +59,7 @@ private[spark] class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Orderin
this += elem1 += elem2 ++= elems
}
- override def clear() { underlying.clear() }
+ override def clear(): Unit = { underlying.clear() }
private def maybeReplaceLowest(a: A): Boolean = {
val head = underlying.peek()
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
index a5ee0ff16b5df..1383e1835028c 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
@@ -67,7 +67,7 @@ class ByteBufferInputStream(private var buffer: ByteBuffer)
/**
* Clean up the buffer, and potentially dispose of it using StorageUtils.dispose().
*/
- private def cleanUp() {
+ private def cleanUp(): Unit = {
if (buffer != null) {
buffer = null
}
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 6d6ef5a744204..d2ad14f2a1a96 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -387,7 +387,7 @@ private[spark] object ClosureCleaner extends Logging {
}
}
- private def ensureSerializable(func: AnyRef) {
+ private def ensureSerializable(func: AnyRef): Unit = {
try {
if (SparkEnv.get != null) {
SparkEnv.get.closureSerializer.newInstance().serialize(func)
@@ -433,7 +433,7 @@ private class ReturnStatementFinder(targetMethodName: Option[String] = None)
name == targetMethodName.get || name == targetMethodName.get.stripSuffix("$adapted")
new MethodVisitor(ASM7) {
- override def visitTypeInsn(op: Int, tp: String) {
+ override def visitTypeInsn(op: Int, tp: String): Unit = {
if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl") && isTargetMethod) {
throw new ReturnStatementInClosureException
}
@@ -480,7 +480,7 @@ private[util] class FieldAccessFinder(
}
new MethodVisitor(ASM7) {
- override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
+ override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = {
if (op == GETFIELD) {
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
fields(cl) += name
@@ -489,7 +489,7 @@ private[util] class FieldAccessFinder(
}
override def visitMethodInsn(
- op: Int, owner: String, name: String, desc: String, itf: Boolean) {
+ op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
// Check for calls a getter method for a variable in an interpreter wrapper object.
// This means that the corresponding field will be accessed, so we should save it.
@@ -528,7 +528,7 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM
// The second closure technically has two inner closures, but this finder only finds one
override def visit(version: Int, access: Int, name: String, sig: String,
- superName: String, interfaces: Array[String]) {
+ superName: String, interfaces: Array[String]): Unit = {
myName = name
}
@@ -536,7 +536,7 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM
sig: String, exceptions: Array[String]): MethodVisitor = {
new MethodVisitor(ASM7) {
override def visitMethodInsn(
- op: Int, owner: String, name: String, desc: String, itf: Boolean) {
+ op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
val argTypes = Type.getArgumentTypes(desc)
if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
&& argTypes(0).toString.startsWith("L") // is it an object?
diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala
index 240dcfbab60ac..550884c873297 100644
--- a/core/src/main/scala/org/apache/spark/util/Distribution.scala
+++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala
@@ -65,7 +65,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va
* print a summary of this distribution to the given PrintStream.
* @param out
*/
- def summary(out: PrintStream = System.out) {
+ def summary(out: PrintStream = System.out): Unit = {
// scalastyle:off println
out.println(statCounter)
showQuantiles(out)
@@ -83,7 +83,7 @@ private[spark] object Distribution {
}
}
- def showQuantiles(out: PrintStream = System.out, quantiles: Iterable[Double]) {
+ def showQuantiles(out: PrintStream = System.out, quantiles: Iterable[Double]): Unit = {
// scalastyle:off println
out.println("min\t25%\t50%\t75%\tmax")
quantiles.foreach{q => out.print(q + "\t")}
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 73ef80980e73f..4d89c4f079f29 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -391,6 +391,7 @@ private[spark] object JsonProtocol {
("Executor Deserialize CPU Time" -> taskMetrics.executorDeserializeCpuTime) ~
("Executor Run Time" -> taskMetrics.executorRunTime) ~
("Executor CPU Time" -> taskMetrics.executorCpuTime) ~
+ ("Peak Execution Memory" -> taskMetrics.peakExecutionMemory) ~
("Result Size" -> taskMetrics.resultSize) ~
("JVM GC Time" -> taskMetrics.jvmGCTime) ~
("Result Serialization Time" -> taskMetrics.resultSerializationTime) ~
@@ -420,6 +421,7 @@ private[spark] object JsonProtocol {
("Block Manager Address" -> blockManagerAddress) ~
("Shuffle ID" -> fetchFailed.shuffleId) ~
("Map ID" -> fetchFailed.mapId) ~
+ ("Map Index" -> fetchFailed.mapIndex) ~
("Reduce ID" -> fetchFailed.reduceId) ~
("Message" -> fetchFailed.message)
case exceptionFailure: ExceptionFailure =>
@@ -893,6 +895,10 @@ private[spark] object JsonProtocol {
case JNothing => 0
case x => x.extract[Long]
})
+ metrics.setPeakExecutionMemory((json \ "Peak Execution Memory") match {
+ case JNothing => 0
+ case x => x.extract[Long]
+ })
metrics.setResultSize((json \ "Result Size").extract[Long])
metrics.setJvmGCTime((json \ "JVM GC Time").extract[Long])
metrics.setResultSerializationTime((json \ "Result Serialization Time").extract[Long])
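
A small standalone json4s sketch of the backward-compatibility pattern used for the new Peak Execution Memory field: a field absent from older event logs extracts as JNothing and falls back to 0. (DefaultFormats here is an assumption of the sketch, not necessarily the formats JsonProtocol uses.)

    import org.json4s._
    import org.json4s.jackson.JsonMethods._

    object OptionalMetricFieldSketch {
      implicit val formats: Formats = DefaultFormats

      def peakExecutionMemory(json: JValue): Long = {
        (json \ "Peak Execution Memory") match {
          case JNothing => 0L        // older event logs do not have this field
          case x => x.extract[Long]
        }
      }

      def main(args: Array[String]): Unit = {
        val oldEvent = parse("""{"Executor Run Time": 10}""")
        val newEvent = parse("""{"Executor Run Time": 10, "Peak Execution Memory": 2048}""")
        println(peakExecutionMemory(oldEvent))  // 0
        println(peakExecutionMemory(newEvent))  // 2048
      }
    }
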
@@ -974,10 +980,11 @@ private[spark] object JsonProtocol {
case `fetchFailed` =>
val blockManagerAddress = blockManagerIdFromJson(json \ "Block Manager Address")
val shuffleId = (json \ "Shuffle ID").extract[Int]
- val mapId = (json \ "Map ID").extract[Int]
+ val mapId = (json \ "Map ID").extract[Long]
+ val mapIndex = (json \ "Map Index").extract[Int]
val reduceId = (json \ "Reduce ID").extract[Int]
val message = jsonOption(json \ "Message").map(_.extract[String])
- new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId,
+ new FetchFailed(blockManagerAddress, shuffleId, mapId, mapIndex, reduceId,
message.getOrElse("Unknown reason"))
case `exceptionFailure` =>
val className = (json \ "Class Name").extract[String]
diff --git a/core/src/main/scala/org/apache/spark/util/NextIterator.scala b/core/src/main/scala/org/apache/spark/util/NextIterator.scala
index 0b505a576768c..0e289025da110 100644
--- a/core/src/main/scala/org/apache/spark/util/NextIterator.scala
+++ b/core/src/main/scala/org/apache/spark/util/NextIterator.scala
@@ -50,7 +50,7 @@ private[spark] abstract class NextIterator[U] extends Iterator[U] {
* Ideally you should have another try/catch, as in HadoopRDD, that
* ensures any resources are closed should iteration fail.
*/
- protected def close()
+ protected def close(): Unit
/**
* Calls the subclass-defined close method, but only once.
@@ -58,7 +58,7 @@ private[spark] abstract class NextIterator[U] extends Iterator[U] {
* Usually calling `close` multiple times should be fine, but historically
* there have been issues with some InputFormats throwing exceptions.
*/
- def closeIfNeeded() {
+ def closeIfNeeded(): Unit = {
if (!closed) {
// Note: it's important that we set closed = true before calling close(), since setting it
// afterwards would permit us to call close() multiple times if close() threw an exception.
diff --git a/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
index c105f3229af09..f01645d82303e 100644
--- a/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
+++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
@@ -24,7 +24,6 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
-import org.apache.spark.storage.StorageLevel
/**
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala
index 3354a923273ff..42d7f71404594 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala
@@ -20,7 +20,14 @@ import java.io.{ObjectInputStream, ObjectOutputStream}
import org.apache.hadoop.conf.Configuration
-private[spark]
+import org.apache.spark.annotation.{DeveloperApi, Unstable}
+
+/**
+ * Hadoop configuration but serializable. Use `value` to access the Hadoop configuration.
+ *
+ * @param value Hadoop configuration
+ */
+@DeveloperApi @Unstable
class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
out.defaultWriteObject()
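
With SerializableConfiguration now exposed as a DeveloperApi, a short usage sketch (assuming an existing SparkContext and a list of paths; not part of the patch): wrap the Hadoop Configuration before shipping it in a broadcast, then unwrap it on the executors via value.

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.SparkContext
    import org.apache.spark.util.SerializableConfiguration

    object SerializableConfSketch {
      def processFiles(sc: SparkContext, paths: Seq[String]): Unit = {
        val hadoopConf = new Configuration()
        // Configuration itself is not serializable; wrap it before broadcasting.
        val confBroadcast = sc.broadcast(new SerializableConfiguration(hadoopConf))
        sc.parallelize(paths).foreach { path =>
          // Unwrap on the executor side.
          val conf: Configuration = confBroadcast.value.value
          // ... use conf, e.g. to open a FileSystem for `path`
        }
      }
    }
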
diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
index b702838fa257f..4f1311224bb95 100644
--- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
+++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
@@ -70,7 +70,7 @@ private[spark] object ShutdownHookManager extends Logging {
}
// Register the path to be deleted via shutdown hook
- def registerShutdownDeleteDir(file: File) {
+ def registerShutdownDeleteDir(file: File): Unit = {
val absolutePath = file.getAbsolutePath()
shutdownDeletePaths.synchronized {
shutdownDeletePaths += absolutePath
@@ -78,7 +78,7 @@ private[spark] object ShutdownHookManager extends Logging {
}
// Remove the path to be deleted via shutdown hook
- def removeShutdownDeleteDir(file: File) {
+ def removeShutdownDeleteDir(file: File): Unit = {
val absolutePath = file.getAbsolutePath()
shutdownDeletePaths.synchronized {
shutdownDeletePaths.remove(absolutePath)
@@ -120,7 +120,7 @@ private[spark] object ShutdownHookManager extends Logging {
def inShutdown(): Boolean = {
try {
val hook = new Thread {
- override def run() {}
+ override def run(): Unit = {}
}
// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(hook)
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 09c69f5c68b03..2caf2a36f9dc6 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -107,7 +107,7 @@ object SizeEstimator extends Logging {
// Sets object size, pointer size based on architecture and CompressedOops settings
// from the JVM.
- private def initialize() {
+ private def initialize(): Unit = {
val arch = System.getProperty("os.arch")
is64bit = arch.contains("64") || arch.contains("s390x")
isCompressedOops = getIsCompressedOops
@@ -171,7 +171,7 @@ object SizeEstimator extends Logging {
val stack = new ArrayBuffer[AnyRef]
var size = 0L
- def enqueue(obj: AnyRef) {
+ def enqueue(obj: AnyRef): Unit = {
if (obj != null && !visited.containsKey(obj)) {
visited.put(obj, null)
stack += obj
@@ -205,7 +205,7 @@ object SizeEstimator extends Logging {
state.size
}
- private def visitSingleObject(obj: AnyRef, state: SearchState) {
+ private def visitSingleObject(obj: AnyRef, state: SearchState): Unit = {
val cls = obj.getClass
if (cls.isArray) {
visitArray(obj, cls, state)
@@ -234,7 +234,7 @@ object SizeEstimator extends Logging {
private val ARRAY_SIZE_FOR_SAMPLING = 400
private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING
- private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) {
+ private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState): Unit = {
val length = ScalaRunTime.array_length(array)
val elementClass = arrayClass.getComponentType()
diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
index 1b34fbde38cd6..2550634681453 100644
--- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
+++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
@@ -28,7 +28,7 @@ import org.apache.spark.internal.Logging
private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException: Boolean = true)
extends Thread.UncaughtExceptionHandler with Logging {
- override def uncaughtException(thread: Thread, exception: Throwable) {
+ override def uncaughtException(thread: Thread, exception: Throwable): Unit = {
try {
// Make it explicit that uncaught exceptions are thrown when container is shutting down.
// It will help users when they analyze the executor logs
@@ -56,7 +56,7 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException:
}
}
- def uncaughtException(exception: Throwable) {
+ def uncaughtException(exception: Throwable): Unit = {
uncaughtException(Thread.currentThread(), exception)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index 32af0127bbf38..550e0674a14e0 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -81,7 +81,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa
this
}
- override def update(key: A, value: B) {
+ override def update(key: A, value: B): Unit = {
this += ((key, value))
}
@@ -97,7 +97,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa
override def size: Int = internalMap.size
- override def foreach[U](f: ((A, B)) => U) {
+ override def foreach[U](f: ((A, B)) => U): Unit = {
val it = getEntrySet.iterator
while(it.hasNext) {
val entry = it.next()
@@ -111,13 +111,13 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa
Option(prev).map(_.value)
}
- def putAll(map: Map[A, B]) {
+ def putAll(map: Map[A, B]): Unit = {
map.foreach { case (k, v) => update(k, v) }
}
def toMap: Map[A, B] = iterator.toMap
- def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
+ def clearOldValues(threshTime: Long, f: (A, B) => Unit): Unit = {
val it = getEntrySet.iterator
while (it.hasNext) {
val entry = it.next()
@@ -130,7 +130,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa
}
/** Removes old key-value pairs that have timestamp earlier than `threshTime`. */
- def clearOldValues(threshTime: Long) {
+ def clearOldValues(threshTime: Long): Unit = {
clearOldValues(threshTime, (_, _) => ())
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 9c1f21fa236ba..f853ec8368366 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -999,7 +999,7 @@ private[spark] object Utils extends Logging {
* Allow setting a custom host name because when we run on Mesos we need to use the same
* hostname it reports to the master.
*/
- def setCustomHostname(hostname: String) {
+ def setCustomHostname(hostname: String): Unit = {
// DEBUG code
Utils.checkHost(hostname)
customHostname = Some(hostname)
@@ -1026,11 +1026,11 @@ private[spark] object Utils extends Logging {
customHostname.getOrElse(InetAddresses.toUriString(localIpAddress))
}
- def checkHost(host: String) {
+ def checkHost(host: String): Unit = {
assert(host != null && host.indexOf(':') == -1, s"Expected hostname (not IP) but got $host")
}
- def checkHostPort(hostPort: String) {
+ def checkHostPort(hostPort: String): Unit = {
assert(hostPort != null && hostPort.indexOf(':') != -1,
s"Expected host and port but got $hostPort")
}
@@ -1280,7 +1280,7 @@ private[spark] object Utils extends Logging {
inputStream: InputStream,
processLine: String => Unit): Thread = {
val t = new Thread(threadName) {
- override def run() {
+ override def run(): Unit = {
for (line <- Source.fromInputStream(inputStream).getLines()) {
processLine(line)
}
@@ -1297,7 +1297,7 @@ private[spark] object Utils extends Logging {
*
* NOTE: This method is to be called by the spark-started JVM process.
*/
- def tryOrExit(block: => Unit) {
+ def tryOrExit(block: => Unit): Unit = {
try {
block
} catch {
@@ -1314,7 +1314,7 @@ private[spark] object Utils extends Logging {
* user-started JVM process completely; in contrast, tryOrExit is to be called in the
* spark-started JVM process .
*/
- def tryOrStopSparkContext(sc: SparkContext)(block: => Unit) {
+ def tryOrStopSparkContext(sc: SparkContext)(block: => Unit): Unit = {
try {
block
} catch {
@@ -1352,7 +1352,7 @@ private[spark] object Utils extends Logging {
}
/** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */
- def tryLogNonFatalError(block: => Unit) {
+ def tryLogNonFatalError(block: => Unit): Unit = {
try {
block
} catch {
@@ -1671,7 +1671,7 @@ private[spark] object Utils extends Logging {
var inSingleQuote = false
var inDoubleQuote = false
val curWord = new StringBuilder
- def endWord() {
+ def endWord(): Unit = {
buf += curWord.toString
curWord.clear()
}
@@ -2342,7 +2342,7 @@ private[spark] object Utils extends Logging {
/**
* configure a new log4j level
*/
- def setLogLevel(l: org.apache.log4j.Level) {
+ def setLogLevel(l: org.apache.log4j.Level): Unit = {
val rootLogger = org.apache.log4j.Logger.getRootLogger()
rootLogger.setLevel(l)
// Setting threshold to null as rootLevel will define log level for spark-shell
@@ -2950,6 +2950,13 @@ private[spark] object Utils extends Logging {
val codec = codecFactory.getCodec(path)
codec == null || codec.isInstanceOf[SplittableCompressionCodec]
}
+
+ /** Create a new properties object with the same values as `props` */
+ def cloneProperties(props: Properties): Properties = {
+ val resultProps = new Properties()
+ props.forEach((k, v) => resultProps.put(k, v))
+ resultProps
+ }
}
private[util] object CallerContext extends Logging {
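
A standalone JDK-only sketch of the cloneProperties helper above, showing that the copy is an independent snapshot of the original's entries:

    import java.util.Properties

    object ClonePropertiesSketch {
      /** Copy the entries of `props` into a fresh Properties object. */
      def cloneProperties(props: Properties): Properties = {
        val result = new Properties()
        props.forEach((k, v) => result.put(k, v))
        result
      }

      def main(args: Array[String]): Unit = {
        val original = new Properties()
        original.setProperty("spark.app.name", "demo")
        val copy = cloneProperties(original)
        original.setProperty("spark.app.name", "changed")
        println(copy.getProperty("spark.app.name"))  // still "demo": the clone is independent
      }
    }
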
@@ -3033,7 +3040,8 @@ private[spark] class CallerContext(
if (CallerContext.callerContextSupported) {
try {
val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
- val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
+ val builder: Class[AnyRef] =
+ Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
val hdfsContext = builder.getMethod("build").invoke(builderInst)
callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
@@ -3056,7 +3064,7 @@ private[spark] class RedirectThread(
extends Thread(name) {
setDaemon(true)
- override def run() {
+ override def run(): Unit = {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
Utils.tryWithSafeFinally {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index bcb95b416dd25..46e311d8b0476 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -198,7 +198,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64)
override def size: Int = curSize
/** Increase table size by 1, rehashing if necessary */
- private def incrementSize() {
+ private def incrementSize(): Unit = {
curSize += 1
if (curSize > growThreshold) {
growTable()
@@ -211,7 +211,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64)
private def rehash(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt()
/** Double the table's size and re-hash everything */
- protected def growTable() {
+ protected def growTable(): Unit = {
// capacity < MAXIMUM_CAPACITY (2 ^ 29) so capacity * 2 won't overflow
val newCapacity = capacity * 2
require(newCapacity <= MAXIMUM_CAPACITY, s"Can't contain more than ${growThreshold} elements")
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index e63e0e3e1f68f..098f389829ec5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -150,12 +150,12 @@ class BitSet(numBits: Int) extends Serializable {
* Sets the bit at the specified index to true.
* @param index the bit index
*/
- def set(index: Int) {
+ def set(index: Int): Unit = {
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
words(index >> 6) |= bitmask // div by 64 and mask
}
- def unset(index: Int) {
+ def unset(index: Int): Unit = {
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
words(index >> 6) &= ~bitmask // div by 64 and mask
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 1ba3b7875f8dc..14409c3661baa 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -549,7 +549,7 @@ class ExternalAppendOnlyMap[K, V, C](
item
}
- private def cleanup() {
+ private def cleanup(): Unit = {
batchIndex = batchOffsets.length // Prevent reading any other batch
if (deserializeStream != null) {
deserializeStream.close()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 7a822e137e556..cc97bbfa7201f 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -23,7 +23,7 @@ import java.util.Comparator
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import com.google.common.io.{ByteStreams, Closeables}
+import com.google.common.io.ByteStreams
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
@@ -534,7 +534,7 @@ private[spark] class ExternalSorter[K, V, C](
* Update partitionId if we have reached the end of our current partition, possibly skipping
* empty partitions on the way.
*/
- private def skipToNextPartition() {
+ private def skipToNextPartition(): Unit = {
while (partitionId < numPartitions &&
indexInPartition == spill.elementsPerPartition(partitionId)) {
partitionId += 1
@@ -605,7 +605,7 @@ private[spark] class ExternalSorter[K, V, C](
}
// Clean up our open streams and put us in a state where we can't read any more data
- def cleanup() {
+ def cleanup(): Unit = {
batchId = batchOffsets.length // Prevent reading any other batch
val ds = deserializeStream
deserializeStream = null
@@ -727,7 +727,7 @@ private[spark] class ExternalSorter[K, V, C](
*/
def writePartitionedMapOutput(
shuffleId: Int,
- mapId: Int,
+ mapId: Long,
mapOutputWriter: ShuffleMapOutputWriter): Unit = {
var nextPartitionId = 0
if (spills.isEmpty) {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index 10ab0b3f89964..1200ac001cce7 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -76,7 +76,7 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
}
/** Set the value for a key */
- def update(k: K, v: V) {
+ def update(k: K, v: V): Unit = {
if (k == null) {
haveNullValue = true
nullValue = v
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 8883e17bf3164..6815e47a198d9 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -113,7 +113,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
* Add an element to the set. If the set is over capacity after the insertion, grow the set
* and rehash all elements.
*/
- def add(k: T) {
+ def add(k: T): Unit = {
addWithoutResize(k)
rehashIfNeeded(k, grow, move)
}
@@ -166,7 +166,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
* @param moveFunc Callback invoked when we move the key from one position (in the old data array)
* to a new position (in the new data array).
*/
- def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit): Unit = {
if (_size > _growThreshold) {
rehash(k, allocateFunc, moveFunc)
}
@@ -227,7 +227,7 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
* @param moveFunc Callback invoked when we move the key from one position (in the old data array)
* to a new position (in the new data array).
*/
- private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit): Unit = {
val newCapacity = _capacity * 2
require(newCapacity > 0 && newCapacity <= OpenHashSet.MAX_CAPACITY,
s"Can't contain more than ${(loadFactor * OpenHashSet.MAX_CAPACITY).toInt} elements")
@@ -320,8 +320,8 @@ object OpenHashSet {
override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o)
}
- private def grow1(newSize: Int) {}
- private def move1(oldPos: Int, newPos: Int) { }
+ private def grow1(newSize: Int): Unit = {}
+ private def move1(oldPos: Int, newPos: Int): Unit = { }
private val grow = grow1 _
private val move = move1 _
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index b4ec4ea521253..7a50d851941ee 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -66,7 +66,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
}
/** Set the value for a key */
- def update(k: K, v: V) {
+ def update(k: K, v: V): Unit = {
val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
_values(pos) = v
_keySet.rehashIfNeeded(k, grow, move)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
index 9a7a5a4e74868..582bd124b5116 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
@@ -87,7 +87,7 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K,
override def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K]
- override def swap(data: Array[T], pos0: Int, pos1: Int) {
+ override def swap(data: Array[T], pos0: Int, pos1: Int): Unit = {
val tmpKey = data(2 * pos0)
val tmpVal = data(2 * pos0 + 1)
data(2 * pos0) = data(2 * pos1)
@@ -96,12 +96,13 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K,
data(2 * pos1 + 1) = tmpVal
}
- override def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) {
+ override def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int): Unit = {
dst(2 * dstPos) = src(2 * srcPos)
dst(2 * dstPos + 1) = src(2 * srcPos + 1)
}
- override def copyRange(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int, length: Int) {
+ override def copyRange(src: Array[T], srcPos: Int,
+ dst: Array[T], dstPos: Int, length: Int): Unit = {
System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length)
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index bfc0face5d8e5..1983b0002853d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -141,7 +141,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
*
* @param size number of bytes spilled
*/
- @inline private def logSpillage(size: Long) {
+ @inline private def logSpillage(size: Long): Unit = {
val threadId = Thread.currentThread().getId
logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)"
.format(threadId, org.apache.spark.util.Utils.bytesToString(size),
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index da8d58d05b6b9..9624b02cb407c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -19,7 +19,6 @@ package org.apache.spark.util.collection
import java.util.Comparator
-import org.apache.spark.storage.DiskBlockObjectWriter
/**
* A common interface for size-tracking collections of key-value pairs that
diff --git a/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala b/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala
index c4540433bce97..4c1b49762ace3 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala
@@ -18,15 +18,18 @@
package org.apache.spark.util.logging
import java.io._
+import java.util.EnumSet
import java.util.concurrent.{ScheduledExecutorService, TimeUnit}
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path}
import org.apache.hadoop.fs.permission.FsPermission
+import org.apache.hadoop.hdfs.client.HdfsDataOutputStream
import org.apache.log4j.{FileAppender => Log4jFileAppender, _}
import org.apache.spark.SparkConf
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.JavaUtils
@@ -111,7 +114,8 @@ private[spark] class DriverLogger(conf: SparkConf) extends Logging {
+ DriverLogger.DRIVER_LOG_FILE_SUFFIX).getAbsolutePath()
try {
inStream = new BufferedInputStream(new FileInputStream(localLogFile))
- outputStream = fileSystem.create(new Path(dfsLogFile), true)
+ outputStream = SparkHadoopUtil.createFile(fileSystem, new Path(dfsLogFile),
+ conf.get(DRIVER_LOG_ALLOW_EC))
fileSystem.setPermission(new Path(dfsLogFile), LOG_FILE_PERMISSIONS)
} catch {
case e: Exception =>
@@ -131,12 +135,20 @@ private[spark] class DriverLogger(conf: SparkConf) extends Logging {
}
try {
var remaining = inStream.available()
+ val hadData = remaining > 0
while (remaining > 0) {
val read = inStream.read(tmpBuffer, 0, math.min(remaining, UPLOAD_CHUNK_SIZE))
outputStream.write(tmpBuffer, 0, read)
remaining -= read
}
- outputStream.hflush()
+ if (hadData) {
+ outputStream match {
+ case hdfsStream: HdfsDataOutputStream =>
+ hdfsStream.hsync(EnumSet.allOf(classOf[HdfsDataOutputStream.SyncFlag]))
+ case other =>
+ other.hflush()
+ }
+ }
} catch {
case e: Exception => logError("Failed writing driver logs to dfs", e)
}
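
A standalone sketch of the flush fallback above, assuming only the Hadoop client classes this file already imports: use the HDFS hsync when the stream supports it, otherwise fall back to hflush.

    import java.util.EnumSet
    import org.apache.hadoop.fs.FSDataOutputStream
    import org.apache.hadoop.hdfs.client.HdfsDataOutputStream

    object DurableFlushSketch {
      // Prefer the stronger HDFS sync when available; plain hflush otherwise.
      def flushDurably(out: FSDataOutputStream): Unit = out match {
        case hdfsStream: HdfsDataOutputStream =>
          hdfsStream.hsync(EnumSet.allOf(classOf[HdfsDataOutputStream.SyncFlag]))
        case other =>
          other.hflush()
      }
    }
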
diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
index 3188e0bd2b70d..7107be25eb505 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
@@ -34,7 +34,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
// Thread that reads the input stream and writes to file
private val writingThread = new Thread("File appending thread for " + file) {
setDaemon(true)
- override def run() {
+ override def run(): Unit = {
Utils.logUncaughtExceptions {
appendStreamToFile()
}
@@ -46,17 +46,17 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
* Wait for the appender to stop appending, either because input stream is closed
* or because of any error in appending
*/
- def awaitTermination() {
+ def awaitTermination(): Unit = {
writingThread.join()
}
/** Stop the appender */
- def stop() {
+ def stop(): Unit = {
markedForStop = true
}
/** Continuously read chunks from the input stream and append to the file */
- protected def appendStreamToFile() {
+ protected def appendStreamToFile(): Unit = {
try {
logDebug("Started appending thread")
Utils.tryWithSafeFinally {
@@ -85,7 +85,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
}
/** Append bytes to the file output stream */
- protected def appendToFile(bytes: Array[Byte], len: Int) {
+ protected def appendToFile(bytes: Array[Byte], len: Int): Unit = {
if (outputStream == null) {
openFile()
}
@@ -93,13 +93,13 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
}
/** Open the file output stream */
- protected def openFile() {
+ protected def openFile(): Unit = {
outputStream = new FileOutputStream(file, true)
logDebug(s"Opened file $file")
}
/** Close the file output stream */
- protected def closeFile() {
+ protected def closeFile(): Unit = {
outputStream.flush()
outputStream.close()
logDebug(s"Closed file $file")
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala
index 59439b68792e5..b73f422649312 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala
@@ -49,12 +49,12 @@ private[spark] class RollingFileAppender(
private val enableCompression = conf.get(config.EXECUTOR_LOGS_ROLLING_ENABLE_COMPRESSION)
/** Stop the appender */
- override def stop() {
+ override def stop(): Unit = {
super.stop()
}
/** Append bytes to file after rolling over is necessary */
- override protected def appendToFile(bytes: Array[Byte], len: Int) {
+ override protected def appendToFile(bytes: Array[Byte], len: Int): Unit = {
if (rollingPolicy.shouldRollover(len)) {
rollover()
rollingPolicy.rolledOver()
@@ -64,7 +64,7 @@ private[spark] class RollingFileAppender(
}
/** Rollover the file, by closing the output stream and moving it over */
- private def rollover() {
+ private def rollover(): Unit = {
try {
closeFile()
moveFile()
@@ -106,7 +106,7 @@ private[spark] class RollingFileAppender(
}
/** Move the active log file to a new rollover file */
- private def moveFile() {
+ private def moveFile(): Unit = {
val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix()
val rolloverFile = new File(
activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile
@@ -138,7 +138,7 @@ private[spark] class RollingFileAppender(
}
/** Retain only last few files */
- private[util] def deleteOldFiles() {
+ private[util] def deleteOldFiles(): Unit = {
try {
val rolledoverFiles = activeFile.getParentFile.listFiles(new FileFilter {
def accept(f: File): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
index 1f263df57c857..5327ecd3e56a9 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
@@ -67,12 +67,12 @@ private[spark] class TimeBasedRollingPolicy(
}
/** Rollover has occurred, so find the next time to rollover */
- def rolledOver() {
+ def rolledOver(): Unit = {
nextRolloverTime = calculateNextRolloverTime()
logDebug(s"Current time: ${System.currentTimeMillis}, next rollover time: " + nextRolloverTime)
}
- def bytesWritten(bytes: Long) { } // nothing to do
+ def bytesWritten(bytes: Long): Unit = { } // nothing to do
private def calculateNextRolloverTime(): Long = {
val now = System.currentTimeMillis()
@@ -118,12 +118,12 @@ private[spark] class SizeBasedRollingPolicy(
}
/** Rollover has occurred, so reset the counter */
- def rolledOver() {
+ def rolledOver(): Unit = {
bytesWrittenSinceRollover = 0
}
/** Increment the bytes that have been written in the current file */
- def bytesWritten(bytes: Long) {
+ def bytesWritten(bytes: Long): Unit = {
bytesWrittenSinceRollover += bytes
}
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 70554f1d03067..6dd2beebbb3dc 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -201,7 +201,7 @@ class PoissonSampler[T](
private val rng = new PoissonDistribution(if (fraction > 0.0) fraction else 1.0)
private val rngGap = RandomSampler.newDefaultRNG
- override def setSeed(seed: Long) {
+ override def setSeed(seed: Long): Unit = {
rng.reseedRandomGenerator(seed)
rngGap.setSeed(seed)
}
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index af09e50a157ae..313569a81646d 100644
--- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -49,7 +49,7 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
(nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
}
- override def setSeed(s: Long) {
+ override def setSeed(s: Long): Unit = {
seed = XORShiftRandom.hashSeed(s)
}
}
@@ -60,7 +60,7 @@ private[spark] object XORShiftRandom {
/** Hash seeds to have 0/1 bits throughout. */
private[random] def hashSeed(seed: Long): Long = {
val bytes = ByteBuffer.allocate(java.lang.Long.BYTES).putLong(seed).array()
- val lowBits = MurmurHash3.bytesHash(bytes)
+ val lowBits = MurmurHash3.bytesHash(bytes, MurmurHash3.arraySeed)
val highBits = MurmurHash3.bytesHash(bytes, lowBits)
(highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL)
}
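
A standalone sketch of hashSeed with the explicit seed above (scala.util.hashing only): the first 32-bit pass is seeded explicitly with arraySeed and the second pass is chained off the first before both halves are packed into a Long.

    import java.nio.ByteBuffer
    import scala.util.hashing.MurmurHash3

    object HashSeedSketch {
      def hashSeed(seed: Long): Long = {
        val bytes = ByteBuffer.allocate(java.lang.Long.BYTES).putLong(seed).array()
        val lowBits = MurmurHash3.bytesHash(bytes, MurmurHash3.arraySeed)  // explicit seed
        val highBits = MurmurHash3.bytesHash(bytes, lowBits)               // chained second pass
        (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL)
      }

      def main(args: Array[String]): Unit = {
        println(hashSeed(42L).toHexString)
      }
    }
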
diff --git a/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java b/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java
index 80cd70282a51d..ade13f02bde73 100644
--- a/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java
+++ b/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java
@@ -17,6 +17,8 @@
package org.apache.spark;
+import com.codahale.metrics.Gauge;
+import com.codahale.metrics.MetricRegistry;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.After;
@@ -30,6 +32,7 @@ public class ExecutorPluginSuite {
private static final String testBadPluginName = TestBadShutdownPlugin.class.getName();
private static final String testPluginName = TestExecutorPlugin.class.getName();
private static final String testSecondPluginName = TestSecondPlugin.class.getName();
+ private static final String testMetricsPluginName = TestMetricsPlugin.class.getName();
// Static value modified by testing plugins to ensure plugins loaded correctly.
public static int numSuccessfulPlugins = 0;
@@ -37,6 +40,10 @@ public class ExecutorPluginSuite {
// Static value modified by testing plugins to verify plugins shut down properly.
public static int numSuccessfulTerminations = 0;
+ // Static values modified by testing plugins to ensure metrics have been registered correctly.
+ public static MetricRegistry testMetricRegistry;
+ public static String gaugeName;
+
private JavaSparkContext sc;
@Before
@@ -107,8 +114,21 @@ public void testPluginShutdownWithException() {
assertEquals(2, numSuccessfulTerminations);
}
+ @Test
+ public void testPluginMetrics() {
+ // Verify that a custom metric is registered with the Spark metrics system
+ gaugeName = "test42";
+ SparkConf conf = initializeSparkConf(testMetricsPluginName);
+ sc = new JavaSparkContext(conf);
+ assertEquals(1, numSuccessfulPlugins);
+ assertEquals(gaugeName, testMetricRegistry.getGauges().firstKey());
+ sc.stop();
+ sc = null;
+ assertEquals(1, numSuccessfulTerminations);
+ }
+
public static class TestExecutorPlugin implements ExecutorPlugin {
- public void init() {
+ public void init(ExecutorPluginContext pluginContext) {
ExecutorPluginSuite.numSuccessfulPlugins++;
}
@@ -118,7 +138,7 @@ public void shutdown() {
}
public static class TestSecondPlugin implements ExecutorPlugin {
- public void init() {
+ public void init(ExecutorPluginContext pluginContext) {
ExecutorPluginSuite.numSuccessfulPlugins++;
}
@@ -128,7 +148,7 @@ public void shutdown() {
}
public static class TestBadShutdownPlugin implements ExecutorPlugin {
- public void init() {
+ public void init(ExecutorPluginContext pluginContext) {
ExecutorPluginSuite.numSuccessfulPlugins++;
}
@@ -136,4 +156,24 @@ public void shutdown() {
throw new RuntimeException("This plugin will fail to cleanly shut down");
}
}
+
+ public static class TestMetricsPlugin implements ExecutorPlugin {
+ public void init(ExecutorPluginContext myContext) {
+ MetricRegistry metricRegistry = myContext.metricRegistry;
+ // Registers a dummy metrics gauge for testing
+ String gaugeName = ExecutorPluginSuite.gaugeName;
+ metricRegistry.register(MetricRegistry.name(gaugeName), new Gauge<Integer>() {
+ @Override
+ public Integer getValue() {
+ return 42;
+ }
+ });
+ ExecutorPluginSuite.testMetricRegistry = metricRegistry;
+ ExecutorPluginSuite.numSuccessfulPlugins++;
+ }
+
+ public void shutdown() {
+ ExecutorPluginSuite.numSuccessfulTerminations++;
+ }
+ }
}
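
A minimal Dropwizard-metrics sketch in Scala (standalone, not the ExecutorPlugin API) of registering and reading back a constant gauge like the one TestMetricsPlugin registers above:

    import com.codahale.metrics.{Gauge, MetricRegistry}

    object PluginGaugeSketch {
      def main(args: Array[String]): Unit = {
        val registry = new MetricRegistry()
        // Register a constant-valued gauge under a fixed name.
        registry.register(MetricRegistry.name("test42"), new Gauge[Integer] {
          override def getValue: Integer = 42
        })
        println(registry.getGauges().firstKey())               // "test42"
        println(registry.getGauges().get("test42").getValue)   // 42
      }
    }
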
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 6b83a984f037c..10e6936eb3799 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -19,8 +19,10 @@
import java.io.*;
import java.nio.ByteBuffer;
+import java.nio.file.Files;
import java.util.*;
+import org.mockito.stubbing.Answer;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
@@ -53,6 +55,7 @@
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.*;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -65,6 +68,7 @@
public class UnsafeShuffleWriterSuite {
+ static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
static final int NUM_PARTITITONS = 4;
TestMemoryManager memoryManager;
TaskMemoryManager taskMemoryManager;
@@ -131,15 +135,29 @@ public void setUp() throws IOException {
);
});
- when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
- doAnswer(invocationOnMock -> {
+ when(shuffleBlockResolver.getDataFile(anyInt(), anyLong())).thenReturn(mergedOutputFile);
+
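+ // Stub writeIndexFileAndCommit: record the partition lengths and move the temp file over the merged output (or create an empty file when no temp file is given).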
+ Answer<?> renameTempAnswer = invocationOnMock -> {
partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
File tmp = (File) invocationOnMock.getArguments()[3];
- mergedOutputFile.delete();
- tmp.renameTo(mergedOutputFile);
+ if (!mergedOutputFile.delete()) {
+ throw new RuntimeException("Failed to delete old merged output file.");
+ }
+ if (tmp != null) {
+ Files.move(tmp.toPath(), mergedOutputFile.toPath());
+ } else if (!mergedOutputFile.createNewFile()) {
+ throw new RuntimeException("Failed to create empty merged output file.");
+ }
return null;
- }).when(shuffleBlockResolver)
- .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
+ };
+
+ doAnswer(renameTempAnswer)
+ .when(shuffleBlockResolver)
+ .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class));
+
+ doAnswer(renameTempAnswer)
+ .when(shuffleBlockResolver)
+ .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null));
when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
@@ -151,21 +169,20 @@ public void setUp() throws IOException {
when(taskContext.taskMetrics()).thenReturn(taskMetrics);
when(shuffleDep.serializer()).thenReturn(serializer);
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+ when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
}
- private UnsafeShuffleWriter<Object, Object> createWriter(
- boolean transferToEnabled) throws IOException {
+ private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
return new UnsafeShuffleWriter<>(
blockManager,
- shuffleBlockResolver,
taskMemoryManager,
- new SerializedShuffleHandle<>(0, 1, shuffleDep),
- 0, // map id
+ new SerializedShuffleHandle<>(0, shuffleDep),
+ 0L, // map id
taskContext,
conf,
- taskContext.taskMetrics().shuffleWriteMetrics()
- );
+ taskContext.taskMetrics().shuffleWriteMetrics(),
+ new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));
}
private void assertSpillFilesWereCleanedUp() {
@@ -391,7 +408,7 @@ public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Except
@Test
public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
- conf.set(package$.MODULE$.SHUFFLE_UNDAFE_FAST_MERGE_ENABLE(), false);
+ conf.set(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE(), false);
testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
}
@@ -444,10 +461,10 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() thro
}
private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
- memoryManager.limit(UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
+ memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
- for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
+ for (int i = 0; i < DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
dataToWrite.add(new Tuple2<>(i, i));
}
writer.write(dataToWrite.iterator());
@@ -516,16 +533,15 @@ public void testPeakMemoryUsed() throws Exception {
final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
taskMemoryManager = spy(taskMemoryManager);
when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
- final UnsafeShuffleWriter<Object, Object> writer =
- new UnsafeShuffleWriter<>(
+ final UnsafeShuffleWriter<Object, Object> writer = new UnsafeShuffleWriter<>(
blockManager,
- shuffleBlockResolver,
taskMemoryManager,
- new SerializedShuffleHandle<>(0, 1, shuffleDep),
- 0, // map id
+ new SerializedShuffleHandle<>(0, shuffleDep),
+ 0L, // map id
taskContext,
conf,
- taskContext.taskMetrics().shuffleWriteMetrics());
+ taskContext.taskMetrics().shuffleWriteMetrics(),
+ new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));
// Peak memory should be monotonically increasing. More specifically, every time
// we allocate a new page it should increase by exactly the size of the page.
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 8d03c6778e18b..6e995a3929a75 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -34,6 +34,7 @@
import org.apache.spark.SparkConf;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.memory.SparkOutOfMemoryError;
import org.apache.spark.memory.TestMemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
@@ -691,13 +692,11 @@ public void avoidDeadlock() throws InterruptedException {
Thread thread = new Thread(() -> {
int i = 0;
- long used = 0;
while (i < 10) {
c1.use(10000000);
- used += 10000000;
i++;
}
- c1.free(used);
+ c1.free(c1.getUsed());
});
try {
@@ -726,4 +725,22 @@ public void avoidDeadlock() throws InterruptedException {
}
}
+ @Test
+ public void freeAfterFailedReset() {
+ // SPARK-29244: BytesToBytesMap.free after a reset that hits OOM should not cause a failure.
+ memoryManager.limit(5000);
+ BytesToBytesMap map =
+ new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 256, 0.5, 4000);
+ // Force OOM on next memory allocation.
+ memoryManager.markExecutionAsOutOfMemoryOnce();
+ try {
+ map.reset();
+ Assert.fail("Expected SparkOutOfMemoryError to be thrown");
+ } catch (SparkOutOfMemoryError e) {
+ // Expected exception; do nothing.
+ } finally {
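+ // free() must still succeed even though reset() failed part-way through.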
+ map.free();
+ }
+ }
+
}
diff --git a/core/src/test/java/org/apache/spark/util/SerializableConfigurationSuite.java b/core/src/test/java/org/apache/spark/util/SerializableConfigurationSuite.java
new file mode 100644
index 0000000000000..0944d681599a1
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/SerializableConfigurationSuite.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util;
+
+import java.util.Arrays;
+
+import org.apache.hadoop.conf.Configuration;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+import static org.junit.Assert.assertEquals;
+
+
+public class SerializableConfigurationSuite {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "SerializableConfigurationSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void testSerializableConfiguration() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+ Configuration hadoopConfiguration = new Configuration(false);
+ hadoopConfiguration.set("test.property", "value");
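+ // Ship the wrapped configuration through a Spark closure and read it back on the driver.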
+ SerializableConfiguration scs = new SerializableConfiguration(hadoopConfiguration);
+ SerializableConfiguration actual = rdd.map(val -> scs).collect().get(0);
+ assertEquals(actual.value().get("test.property"), "value");
+ }
+}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index c6aa623560d57..d5b1a1c5f547d 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -235,6 +235,9 @@ public void testSortTimeMetric() throws Exception {
sorter.insertRecord(null, 0, 0, 0, false);
UnsafeSorterIterator iter = sorter.getSortedIterator();
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
+
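+ // Release the sorter's resources and verify that no spill files are left behind.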
+ sorter.cleanupResources();
+ assertSpillFilesWereCleanedUp();
}
@Test
@@ -510,6 +513,8 @@ public void testGetIterator() throws Exception {
verifyIntIterator(sorter.getIterator(79), 79, 300);
verifyIntIterator(sorter.getIterator(139), 139, 300);
verifyIntIterator(sorter.getIterator(279), 279, 300);
+ sorter.cleanupResources();
+ assertSpillFilesWereCleanedUp();
}
@Test
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 435665d8a1ce2..a75cf3f0381df 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -126,7 +126,7 @@ private[spark] object AccumulatorSuite {
sc.addSparkListener(listener)
testBody
// wait until all events have been processed before proceeding to assert things
- sc.listenerBus.waitUntilEmpty(10 * 1000)
+ sc.listenerBus.waitUntilEmpty()
val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
val isSet = accums.exists { a =>
a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L)
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 6a30a1d32f8c6..92ed24408384f 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -97,7 +97,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[So
}
/** Run GC and make sure it actually has run */
- protected def runGC() {
+ protected def runGC(): Unit = {
val weakRef = new WeakReference(new Object())
val startTimeNs = System.nanoTime()
System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
@@ -406,7 +406,7 @@ class CleanerTester(
sc.cleaner.get.attachListener(cleanerListener)
/** Assert that all the stuff has been cleaned up */
- def assertCleanup()(implicit waitTimeout: PatienceConfiguration.Timeout) {
+ def assertCleanup()(implicit waitTimeout: PatienceConfiguration.Timeout): Unit = {
try {
eventually(waitTimeout, interval(100.milliseconds)) {
assert(isAllCleanedUp,
@@ -419,7 +419,7 @@ class CleanerTester(
}
/** Verify that RDDs, shuffles, etc. occupy resources */
- private def preCleanupValidate() {
+ private def preCleanupValidate(): Unit = {
assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty ||
checkpointIds.nonEmpty, "Nothing to cleanup")
@@ -465,7 +465,7 @@ class CleanerTester(
* Verify that RDDs, shuffles, etc. do not occupy resources. Tests multiple times as there is
* no guarantee on how long it will take to clean up the resources.
*/
- private def postCleanupValidate() {
+ private def postCleanupValidate(): Unit = {
// Verify the RDDs have been persisted and blocks are present
rddIds.foreach { rddId =>
assert(
diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala
index a5bdc95790722..1d3e28b39548f 100644
--- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala
+++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala
@@ -21,7 +21,6 @@ import java.io.{FileDescriptor, InputStream}
import java.lang
import java.nio.ByteBuffer
-import scala.collection.JavaConverters._
import scala.collection.mutable
import org.apache.hadoop.fs._
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index 182f28c5cce54..f58777584d0ae 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -50,7 +50,7 @@ class DriverSuite extends SparkFunSuite with TimeLimits {
* sys.exit() after finishing.
*/
object DriverWithoutCleanup {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
TestUtils.configTestLog4j("INFO")
val conf = new SparkConf
val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf)
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 07fb323cfc355..460714f204a3a 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -64,7 +64,7 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite {
private def post(event: SparkListenerEvent): Unit = {
listenerBus.post(event)
- listenerBus.waitUntilEmpty(1000)
+ listenerBus.waitUntilEmpty()
}
test("initialize dynamic allocation in SparkContext") {
diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
index 7f7f3db65d6ca..8844a0598ccb8 100644
--- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -40,7 +40,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi
var transportContext: TransportContext = _
var rpcHandler: ExternalBlockHandler = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2)
rpcHandler = new ExternalBlockHandler(transportConf, null)
@@ -52,7 +52,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi
conf.set(config.SHUFFLE_SERVICE_PORT, server.getPort)
}
- override def afterAll() {
+ override def afterAll(): Unit = {
Utils.tryLogNonFatalError{
server.close()
}
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 5f79b526a419b..8b75c3a0ba653 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -31,7 +31,7 @@ object FailureSuiteState {
var tasksRun = 0
var tasksFailed = 0
- def clear() {
+ def clear(): Unit = {
synchronized {
tasksRun = 0
tasksFailed = 0
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 6651e38f7ed62..c7ea195cc95e3 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -42,12 +42,12 @@ import org.apache.spark.util.Utils
class FileSuite extends SparkFunSuite with LocalSparkContext {
var tempDir: File = _
- override def beforeEach() {
+ override def beforeEach(): Unit = {
super.beforeEach()
tempDir = Utils.createTempDir()
}
- override def afterEach() {
+ override def afterEach(): Unit = {
try {
Utils.deleteRecursively(tempDir)
} finally {
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index e7eef8ec5150c..8433a6f52ac7a 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -142,6 +142,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
sid,
taskContext.partitionId(),
taskContext.partitionId(),
+ taskContext.partitionId(),
"simulated fetch failure")
} else {
iter
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index b533304287cf6..94ad8d8880027 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -40,7 +40,7 @@ import org.apache.spark.util.ThreadUtils
class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAfter
with LocalSparkContext {
- override def afterEach() {
+ override def afterEach(): Unit = {
try {
resetSparkContext()
JobCancellationSuite.taskStartedSemaphore.drainPermits()
@@ -127,7 +127,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.addSparkListener(new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
sem.release()
}
})
@@ -157,7 +157,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.addSparkListener(new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
sem.release()
}
})
@@ -192,7 +192,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.addSparkListener(new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
sem.release()
}
})
@@ -225,7 +225,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.addSparkListener(new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
sem.release()
}
})
@@ -264,7 +264,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.addSparkListener(new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
sem.release()
}
})
@@ -301,7 +301,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
sc = new SparkContext("local[2]", "test")
sc.addSparkListener(new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
sem1.release()
}
})
@@ -391,7 +391,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
assert(executionOfInterruptibleCounter.get() < numElements)
}
- def testCount() {
+ def testCount(): Unit = {
// Cancel before launching any tasks
{
val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
@@ -405,7 +405,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.addSparkListener(new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
sem.release()
}
})
@@ -421,7 +421,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
}
}
- def testTake() {
+ def testTake(): Unit = {
// Cancel before launching any tasks
{
val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
@@ -435,7 +435,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.addSparkListener(new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
sem.release()
}
})
diff --git a/core/src/test/scala/org/apache/spark/JsonTestUtils.scala b/core/src/test/scala/org/apache/spark/JsonTestUtils.scala
index ba367cd476146..8aa7f3c7cb1bf 100644
--- a/core/src/test/scala/org/apache/spark/JsonTestUtils.scala
+++ b/core/src/test/scala/org/apache/spark/JsonTestUtils.scala
@@ -20,7 +20,7 @@ import org.json4s._
import org.json4s.jackson.JsonMethods
trait JsonTestUtils {
- def assertValidDataInJson(validateJson: JValue, expectedJson: JValue) {
+ def assertValidDataInJson(validateJson: JValue, expectedJson: JValue): Unit = {
val Diff(c, a, d) = validateJson.diff(expectedJson)
val validatePretty = JsonMethods.pretty(validateJson)
val expectedPretty = JsonMethods.pretty(expectedJson)
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
index 05aaaa11451b4..d050ee2c45e7a 100644
--- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -27,12 +27,12 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self
@transient var sc: SparkContext = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
}
- override def afterEach() {
+ override def afterEach(): Unit = {
try {
resetSparkContext()
} finally {
@@ -48,7 +48,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self
}
object LocalSparkContext {
- def stop(sc: SparkContext) {
+ def stop(sc: SparkContext): Unit = {
if (sc != null) {
sc.stop()
}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index d86975964b558..da2ba2165bb0c 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -64,14 +64,15 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
- Array(1000L, 10000L)))
+ Array(1000L, 10000L), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
- Array(10000L, 1000L)))
+ Array(10000L, 1000L), 6))
val statuses = tracker.getMapSizesByExecutorId(10, 0)
assert(statuses.toSet ===
- Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))),
- (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000))))
- .toSet)
+ Seq((BlockManagerId("a", "hostA", 1000),
+ ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0))),
+ (BlockManagerId("b", "hostB", 1000),
+ ArrayBuffer((ShuffleBlockId(10, 6, 0), size10000, 1)))).toSet)
assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.stop()
rpcEnv.shutdown()
@@ -86,9 +87,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val compressedSize1000 = MapStatus.compressSize(1000L)
val compressedSize10000 = MapStatus.compressSize(10000L)
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
- Array(compressedSize1000, compressedSize10000)))
+ Array(compressedSize1000, compressedSize10000), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
- Array(compressedSize10000, compressedSize1000)))
+ Array(compressedSize10000, compressedSize1000), 6))
assert(tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
assert(0 == tracker.getNumCachedSerializedBroadcast)
@@ -109,9 +110,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val compressedSize1000 = MapStatus.compressSize(1000L)
val compressedSize10000 = MapStatus.compressSize(10000L)
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
- Array(compressedSize1000, compressedSize1000, compressedSize1000)))
+ Array(compressedSize1000, compressedSize1000, compressedSize1000), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
- Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+ Array(compressedSize10000, compressedSize1000, compressedSize1000), 6))
assert(0 == tracker.getNumCachedSerializedBroadcast)
// As if we had two simultaneous fetch failures
@@ -147,10 +148,11 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
masterTracker.registerMapOutput(10, 0, MapStatus(
- BlockManagerId("a", "hostA", 1000), Array(1000L)))
+ BlockManagerId("a", "hostA", 1000), Array(1000L), 5))
slaveTracker.updateEpoch(masterTracker.getEpoch)
assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
+ Seq((BlockManagerId("a", "hostA", 1000),
+ ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0)))))
assert(0 == masterTracker.getNumCachedSerializedBroadcast)
val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
@@ -184,7 +186,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// Message size should be ~123B, and no exception should be thrown
masterTracker.registerShuffle(10, 1)
masterTracker.registerMapOutput(10, 0, MapStatus(
- BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0)))
+ BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 5))
val senderAddress = RpcAddress("localhost", 12345)
val rpcCallContext = mock(classOf[RpcCallContext])
when(rpcCallContext.senderAddress).thenReturn(senderAddress)
@@ -218,11 +220,11 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// on hostB with output size 3
tracker.registerShuffle(10, 3)
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
- Array(2L)))
+ Array(2L), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000),
- Array(2L)))
+ Array(2L), 6))
tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000),
- Array(3L)))
+ Array(3L), 7))
// When the threshold is 50%, only host A should be returned as a preferred location
// as it has 4 out of 7 bytes of output.
@@ -262,7 +264,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerShuffle(20, 100)
(0 until 100).foreach { i =>
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
- BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
+ BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
}
val senderAddress = RpcAddress("localhost", 12345)
val rpcCallContext = mock(classOf[RpcCallContext])
@@ -311,16 +313,18 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
- Array(size0, size1000, size0, size10000)))
+ Array(size0, size1000, size0, size10000), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
- Array(size10000, size0, size1000, size0)))
+ Array(size10000, size0, size1000, size0), 6))
assert(tracker.containsShuffle(10))
- assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq ===
+ assert(tracker.getMapSizesByExecutorId(10, 0, 4, false).toSeq ===
Seq(
(BlockManagerId("a", "hostA", 1000),
- Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))),
+ Seq((ShuffleBlockId(10, 5, 1), size1000, 0),
+ (ShuffleBlockId(10, 5, 3), size10000, 0))),
(BlockManagerId("b", "hostB", 1000),
- Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000)))
+ Seq((ShuffleBlockId(10, 6, 0), size10000, 1),
+ (ShuffleBlockId(10, 6, 2), size1000, 1)))
)
)
diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
index 1aa1c421d792e..bdeb631878350 100644
--- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -43,12 +43,12 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel
}
}
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
initializeContext()
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
LocalSparkContext.stop(_sc)
_sc = null
diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
index 73638d9b131ea..378a361845139 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
@@ -23,7 +23,7 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
conf.set("spark.shuffle.blockTransferService", "netty")
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 923c9c90447fd..c652f879cc8f9 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListene
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId}
-import org.apache.spark.util.{MutablePair, Utils}
+import org.apache.spark.util.MutablePair
abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext {
@@ -360,7 +360,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
val metricsSystem = sc.env.metricsSystem
val shuffleMapRdd = new MyRDD(sc, 1, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
- val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+ val shuffleHandle = manager.registerShuffle(0, shuffleDep)
mapTrackerMaster.registerShuffle(0, 1)
// first attempt -- it's successful
@@ -487,7 +487,7 @@ object ShuffleSuite {
@volatile var bytesWritten: Long = 0
@volatile var bytesRead: Long = 0
val listener = new SparkListener {
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
recordsWritten += taskEnd.taskMetrics.shuffleWriteMetrics.recordsWritten
bytesWritten += taskEnd.taskMetrics.shuffleWriteMetrics.bytesWritten
recordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
@@ -498,7 +498,7 @@ object ShuffleSuite {
job
- sc.listenerBus.waitUntilEmpty(500)
+ sc.listenerBus.waitUntilEmpty()
AggregatedShuffleMetrics(recordsWritten, recordsRead, bytesWritten, bytesRead)
}
}
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 1aceda498d7c7..1a563621a5179 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -37,7 +37,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
private var tempDir: File = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
// Once 'spark.local.dir' is set, it is cached. Unless this is manually cleared
// before/after a test, it could return the same directory even if this property
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 9f00131c8dc20..0ac6ba2d76c6f 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -463,7 +463,7 @@ class Class2 {}
class Class3 {}
class CustomRegistrator extends KryoRegistrator {
- def registerClasses(kryo: Kryo) {
+ def registerClasses(kryo: Kryo): Unit = {
kryo.register(classOf[Class2])
}
}
diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
index 536b4aec75623..6271ce507fddb 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
@@ -63,7 +63,7 @@ class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(sc.getRDDStorageInfo.length === 0)
rdd.collect()
- sc.listenerBus.waitUntilEmpty(10000)
+ sc.listenerBus.waitUntilEmpty()
eventually(timeout(10.seconds), interval(100.milliseconds)) {
assert(sc.getRDDStorageInfo.length === 1)
}
@@ -82,7 +82,7 @@ class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
package object testPackage extends Assertions {
private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r
- def runCallSiteTest(sc: SparkContext) {
+ def runCallSiteTest(sc: SparkContext): Unit = {
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
val rddCreationSite = rdd.getCreationSite
val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 786f55c96a3e8..4fd862888dcc6 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -450,7 +450,9 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
sc.setLocalProperty("testProperty", "testValue")
var result = "unset";
- val thread = new Thread() { override def run() = {result = sc.getLocalProperty("testProperty")}}
+ val thread = new Thread() {
+ override def run(): Unit = {result = sc.getLocalProperty("testProperty")}
+ }
thread.start()
thread.join()
sc.stop()
@@ -461,10 +463,10 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
var result = "unset";
val thread1 = new Thread() {
- override def run() = {sc.setLocalProperty("testProperty", "testValue")}}
+ override def run(): Unit = {sc.setLocalProperty("testProperty", "testValue")}}
// testProperty should be unset and thus return null
val thread2 = new Thread() {
- override def run() = {result = sc.getLocalProperty("testProperty")}}
+ override def run(): Unit = {result = sc.getLocalProperty("testProperty")}}
thread1.start()
thread1.join()
thread2.start()
@@ -705,7 +707,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
if (context.stageAttemptNumber == 0) {
if (context.partitionId == 0) {
// Make the first task in the first stage attempt fail.
- throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 0, 0, 0,
+ throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 0, 0L, 0, 0,
new java.io.IOException("fake"))
} else {
// Make the second task in the first stage attempt sleep to generate a zombie task
@@ -716,7 +718,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
}
x
}.collect()
- sc.listenerBus.waitUntilEmpty(10000)
+ sc.listenerBus.waitUntilEmpty()
// As executors will send the metrics of running tasks via heartbeat, we can use this to check
// whether there is any running task.
eventually(timeout(10.seconds)) {
@@ -761,7 +763,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
sc = new SparkContext(conf)
// Ensure all executors have started
- TestUtils.waitUntilExecutorsUp(sc, 1, 10000)
+ TestUtils.waitUntilExecutorsUp(sc, 1, 60000)
assert(sc.resources.size === 1)
assert(sc.resources.get(GPU).get.addresses === Array("5", "6"))
assert(sc.resources.get(GPU).get.name === "gpu")
@@ -790,7 +792,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
sc = new SparkContext(conf)
// Ensure all executors have started
- TestUtils.waitUntilExecutorsUp(sc, 1, 10000)
+ TestUtils.waitUntilExecutorsUp(sc, 1, 60000)
// driver gpu resources file should take precedence over the script
assert(sc.resources.size === 1)
assert(sc.resources.get(GPU).get.addresses === Array("0", "1", "8"))
diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
index 5cf9c087e1dcb..bb04d0d263253 100644
--- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
@@ -29,7 +29,7 @@ object ThreadingSuiteState {
val runningThreads = new AtomicInteger
val failed = new AtomicBoolean
- def clear() {
+ def clear(): Unit = {
runningThreads.set(0)
failed.set(false)
}
@@ -44,7 +44,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
@volatile var answer1: Int = 0
@volatile var answer2: Int = 0
new Thread {
- override def run() {
+ override def run(): Unit = {
answer1 = nums.reduce(_ + _)
answer2 = nums.first() // This will run "locally" in the current thread
sem.release()
@@ -62,7 +62,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
@volatile var ok = true
for (i <- 0 until 10) {
new Thread {
- override def run() {
+ override def run(): Unit = {
val answer1 = nums.reduce(_ + _)
if (answer1 != 55) {
printf("In thread %d: answer1 was %d\n", i, answer1)
@@ -90,7 +90,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
@volatile var ok = true
for (i <- 0 until 10) {
new Thread {
- override def run() {
+ override def run(): Unit = {
val answer1 = nums.reduce(_ + _)
if (answer1 != 55) {
printf("In thread %d: answer1 was %d\n", i, answer1)
@@ -121,7 +121,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
var throwable: Option[Throwable] = None
for (i <- 0 until 2) {
new Thread {
- override def run() {
+ override def run(): Unit = {
try {
val ans = nums.map(number => {
val running = ThreadingSuiteState.runningThreads
@@ -161,7 +161,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
var throwable: Option[Throwable] = None
val threads = (1 to 5).map { i =>
new Thread() {
- override def run() {
+ override def run(): Unit = {
try {
sc.setLocalProperty("test", i.toString)
assert(sc.getLocalProperty("test") === i.toString)
@@ -189,7 +189,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
var throwable: Option[Throwable] = None
val threads = (1 to 5).map { i =>
new Thread() {
- override def run() {
+ override def run(): Unit = {
try {
assert(sc.getLocalProperty("test") === "parent")
sc.setLocalProperty("test", i.toString)
diff --git a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala
index 73f9d0e2bc0e1..022fcbb25b0af 100644
--- a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala
+++ b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala
@@ -141,12 +141,14 @@ private[spark] class Benchmark(
val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos
val runTimes = ArrayBuffer[Long]()
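+ // Track total time incrementally so the loop condition does not re-sum runTimes on every iteration.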
+ var totalTime = 0L
var i = 0
- while (i < minIters || runTimes.sum < minDuration) {
+ while (i < minIters || totalTime < minDuration) {
val timer = new Benchmark.Timer(i)
f(timer)
val runTime = timer.totalTime()
runTimes += runTime
+ totalTime += runTime
if (outputPerIteration) {
// scalastyle:off
diff --git a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala
index a6666db4e95c3..55e34b32fe0d4 100644
--- a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala
+++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala
@@ -21,6 +21,7 @@ import java.io.{File, FileOutputStream, OutputStream}
/**
* A base class for generate benchmark results to a file.
+ * For JDK 9+, the JDK major version number is added to the file names to distinguish the results.
*/
abstract class BenchmarkBase {
var output: Option[OutputStream] = None
@@ -43,7 +44,9 @@ abstract class BenchmarkBase {
def main(args: Array[String]): Unit = {
val regenerateBenchmarkFiles: Boolean = System.getenv("SPARK_GENERATE_BENCHMARK_FILES") == "1"
if (regenerateBenchmarkFiles) {
- val resultFileName = s"${this.getClass.getSimpleName.replace("$", "")}-results.txt"
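+ // Parse the JDK major version from java.version (e.g. "11.0.4" -> 11) so JDK 9+ runs get their own results file.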
+ val version = System.getProperty("java.version").split("\\D+")(0).toInt
+ val jdkString = if (version > 8) s"-jdk$version" else ""
+ val resultFileName = s"${this.getClass.getSimpleName.replace("$", "")}$jdkString-results.txt"
val file = new File(s"benchmarks/$resultFileName")
if (!file.exists()) {
file.createNewFile()
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 66b2f487dc1cb..a6776ee077894 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -194,11 +194,12 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio
* In between each step, this test verifies that the broadcast blocks are present only on the
* expected nodes.
*/
- private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+ private def testUnpersistTorrentBroadcast(distributed: Boolean,
+ removeFromDriver: Boolean): Unit = {
val numSlaves = if (distributed) 2 else 0
// Verify that blocks are persisted only on the driver
- def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+ def afterCreation(broadcastId: Long, bmm: BlockManagerMaster): Unit = {
var blockId = BroadcastBlockId(broadcastId)
var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === 1)
@@ -209,7 +210,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio
}
// Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+ def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster): Unit = {
var blockId = BroadcastBlockId(broadcastId)
val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === numSlaves + 1)
@@ -220,7 +221,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true.
- def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+ def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster): Unit = {
var blockId = BroadcastBlockId(broadcastId)
var expectedNumBlocks = if (removeFromDriver) 0 else 1
var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
@@ -251,7 +252,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio
afterCreation: (Long, BlockManagerMaster) => Unit,
afterUsingBroadcast: (Long, BlockManagerMaster) => Unit,
afterUnpersist: (Long, BlockManagerMaster) => Unit,
- removeFromDriver: Boolean) {
+ removeFromDriver: Boolean): Unit = {
sc = if (distributed) {
val _sc =
@@ -307,7 +308,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio
package object testPackage extends Assertions {
- def runCallSiteTest(sc: SparkContext) {
+ def runCallSiteTest(sc: SparkContext): Unit = {
val broadcast = sc.broadcast(Array(1, 2, 3, 4))
broadcast.destroy(blocking = true)
val thrown = intercept[SparkException] { broadcast.value }
diff --git a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala
index 9cfb8a647ad89..6914714dce6eb 100644
--- a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala
@@ -46,7 +46,7 @@ class ExternalShuffleServiceDbSuite extends SparkFunSuite {
var blockHandler: ExternalBlockHandler = _
var blockResolver: ExternalShuffleBlockResolver = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
sparkConf = new SparkConf()
sparkConf.set("spark.shuffle.service.enabled", "true")
@@ -63,7 +63,7 @@ class ExternalShuffleServiceDbSuite extends SparkFunSuite {
registerExecutor()
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
dataContext.cleanup()
} finally {
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index ad402c0e905ae..eeccf56cbf02e 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -89,7 +89,7 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils {
assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerStateJsonStr))
}
- def assertValidJson(json: JValue) {
+ def assertValidJson(json: JValue): Unit = {
try {
JsonMethods.parse(JsonMethods.compact(json))
} catch {
diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala
index cbdf1755b0c5b..84fc16979925b 100644
--- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala
@@ -29,9 +29,6 @@ import org.apache.spark.util.SparkConfWithEnv
class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext {
- /** Length of time to wait while draining listener events. */
- private val WAIT_TIMEOUT_MILLIS = 10000
-
test("verify that correct log urls get propagated from workers") {
sc = new SparkContext("local-cluster[2,1,1024]", "test")
@@ -41,7 +38,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext {
// Trigger a job so that executors get added
sc.parallelize(1 to 100, 4).map(_.toString).count()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
listener.addedExecutorInfos.values.foreach { info =>
assert(info.logUrlMap.nonEmpty)
// Browse to each URL to check that it's valid
@@ -61,7 +58,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext {
// Trigger a job so that executors get added
sc.parallelize(1 to 100, 4).map(_.toString).count()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo]
assert(listeners.size === 1)
val listener = listeners(0)
@@ -77,7 +74,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext {
private[spark] class SaveExecutorInfo extends SparkListener {
val addedExecutorInfos = mutable.Map[String, ExecutorInfo]()
- override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
+ override def onExecutorAdded(executor: SparkListenerExecutorAdded): Unit = {
addedExecutorInfos(executor.executorId) = executor.executorInfo
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
index ef947eb074647..d04d9b6dcb2be 100644
--- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
@@ -58,7 +58,7 @@ class RPackageUtilsSuite
/** Simple PrintStream that reads data into a buffer */
private class BufferPrintStream extends PrintStream(noOpOutputStream) {
// scalastyle:off println
- override def println(line: String) {
+ override def println(line: String): Unit = {
// scalastyle:on println
lineBuffer += line
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 385f549aa1ad9..5b81671edb149 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -57,7 +57,7 @@ trait TestPrematureExit {
private class BufferPrintStream extends PrintStream(noOpOutputStream) {
var lineBuffer = ArrayBuffer[String]()
// scalastyle:off println
- override def println(line: String) {
+ override def println(line: String): Unit = {
lineBuffer += line
}
// scalastyle:on println
@@ -121,7 +121,7 @@ class SparkSubmitSuite
private val submit = new SparkSubmit()
- override def beforeEach() {
+ override def beforeEach(): Unit = {
super.beforeEach()
}
@@ -600,7 +600,7 @@ class SparkSubmitSuite
}
// TODO(SPARK-9603): Building a package is flaky on Jenkins Maven builds.
- // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log
+ // See https://gist.github.com/shivaram/3a2fecce60768a603dac for an error log
ignore("correctly builds R packages included in a jar with --packages") {
assume(RUtils.isRInstalled, "R isn't installed on this machine.")
assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.")
@@ -1365,7 +1365,7 @@ object SparkSubmitSuite extends SparkFunSuite with TimeLimits {
}
object JarCreationTest extends Logging {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
TestUtils.configTestLog4j("INFO")
val conf = new SparkConf()
val sc = new SparkContext(conf)
@@ -1389,7 +1389,7 @@ object JarCreationTest extends Logging {
}
object SimpleApplicationTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
TestUtils.configTestLog4j("INFO")
val conf = new SparkConf()
val sc = new SparkContext(conf)
@@ -1415,7 +1415,7 @@ object SimpleApplicationTest {
}
object UserClasspathFirstTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val ccl = Thread.currentThread().getContextClassLoader()
val resource = ccl.getResourceAsStream("test.resource")
val bytes = ByteStreams.toByteArray(resource)
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
index 8e1a519e187ce..31e6c730eadc0 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
@@ -44,13 +44,13 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
private class BufferPrintStream extends PrintStream(noOpOutputStream) {
var lineBuffer = ArrayBuffer[String]()
// scalastyle:off println
- override def println(line: String) {
+ override def println(line: String): Unit = {
lineBuffer += line
}
// scalastyle:on println
}
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
// We don't want to write logs during testing
SparkSubmitUtils.printStream = new BufferPrintStream
diff --git a/core/src/test/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/test/scala/org/apache/spark/deploy/client/TestExecutor.scala
index a98b1fa8f83a1..1dce49d1f9d5a 100644
--- a/core/src/test/scala/org/apache/spark/deploy/client/TestExecutor.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/client/TestExecutor.scala
@@ -18,7 +18,7 @@
package org.apache.spark.deploy.client
private[spark] object TestExecutor {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
// scalastyle:off println
println("Hello world!")
// scalastyle:on println
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala
index 1148446c9faa1..48bd088d07ff9 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala
@@ -28,7 +28,7 @@ import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.scalatest.Matchers
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index 30261dde678f1..1d465ba37364b 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -86,7 +86,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with Matchers with Logging {
}
}
- private def testAppLogParsing(inMemory: Boolean) {
+ private def testAppLogParsing(inMemory: Boolean): Unit = {
val clock = new ManualClock(12345678)
val conf = createTestConf(inMemory = inMemory)
val provider = new FsHistoryProvider(conf, clock)
@@ -1254,7 +1254,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with Matchers with Logging {
private def writeFile(file: File, codec: Option[CompressionCodec],
events: SparkListenerEvent*) = {
val fstream = new FileOutputStream(file)
- val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream)
+ val cstream = codec.map(_.compressedContinuousOutputStream(fstream)).getOrElse(fstream)
val bstream = new BufferedOutputStream(cstream)
EventLoggingListener.initEventLog(bstream, false, null)
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index dbc1938ed469a..17c1528b67588 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -26,7 +26,6 @@ import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpSe
import scala.collection.JavaConverters._
import scala.concurrent.duration._
-import com.gargoylesoftware.htmlunit.BrowserVersion
import com.google.common.io.{ByteStreams, Files}
import org.apache.commons.io.{FileUtils, IOUtils}
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
@@ -40,8 +39,8 @@ import org.openqa.selenium.WebDriver
import org.openqa.selenium.htmlunit.HtmlUnitDriver
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually
-import org.scalatest.mockito.MockitoSugar
-import org.scalatest.selenium.WebBrowser
+import org.scalatestplus.mockito.MockitoSugar
+import org.scalatestplus.selenium.WebBrowser
import org.apache.spark._
import org.apache.spark.internal.config._
@@ -94,6 +93,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
server = new HistoryServer(conf, provider, securityManager, 18080)
server.initialize()
server.bind()
+ provider.start()
port = server.boundPort
}
@@ -364,8 +364,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
contextHandler.addServlet(holder, "/")
server.attachHandler(contextHandler)
- implicit val webDriver: WebDriver =
- new HtmlUnitDriver(BrowserVersion.INTERNET_EXPLORER_11, true)
+ implicit val webDriver: WebDriver = new HtmlUnitDriver(true)
try {
val url = s"http://localhost:$port"
@@ -451,6 +450,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
server = new HistoryServer(myConf, provider, securityManager, 0)
server.initialize()
server.bind()
+ provider.start()
val port = server.boundPort
val metrics = server.cacheMetrics
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala
index f4558aa3eb893..e2d7facdd77e0 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala
@@ -47,12 +47,12 @@ class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll {
when(master.self).thenReturn(masterEndpointRef)
val masterWebUI = new MasterWebUI(master, 0)
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
masterWebUI.bind()
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
masterWebUI.stop()
} finally {
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
index 89b8bb4ff7d03..d5312845a3b50 100644
--- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -42,7 +42,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
private var rpcEnv: Option[RpcEnv] = None
private var server: Option[RestSubmissionServer] = None
- override def afterEach() {
+ override def afterEach(): Unit = {
try {
rpcEnv.foreach(_.shutdown())
server.foreach(_.stop())
diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala
index 70174f7ff939a..275bca3459855 100644
--- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala
@@ -17,11 +17,17 @@
package org.apache.spark.deploy.security
+import java.security.PrivilegedExceptionAction
+
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.security.Credentials
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION
+import org.apache.hadoop.minikdc.MiniKdc
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.security.HadoopDelegationTokenProvider
+import org.apache.spark.util.Utils
private class ExceptionThrowingDelegationTokenProvider extends HadoopDelegationTokenProvider {
ExceptionThrowingDelegationTokenProvider.constructed = true
@@ -69,4 +75,48 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite {
assert(!manager.isProviderLoaded("hadoopfs"))
assert(manager.isProviderLoaded("hbase"))
}
+
+ test("SPARK-29082: do not fail if current user does not have credentials") {
+ // SparkHadoopUtil overrides the UGI configuration during initialization. That normally
+ // happens early in the Spark application, but here it may affect the test depending on
+ // how it's run, so force its initialization.
+ SparkHadoopUtil.get
+
+ var kdc: MiniKdc = null
+ try {
+ // UserGroupInformation.setConfiguration needs default kerberos realm which can be set in
+ // krb5.conf. MiniKdc sets "java.security.krb5.conf" on start and removes it when stop is called.
+ val kdcDir = Utils.createTempDir()
+ val kdcConf = MiniKdc.createConf()
+ kdc = new MiniKdc(kdcConf, kdcDir)
+ kdc.start()
+
+ val krbConf = new Configuration()
+ krbConf.set(HADOOP_SECURITY_AUTHENTICATION, "kerberos")
+
+ UserGroupInformation.setConfiguration(krbConf)
+ val manager = new HadoopDelegationTokenManager(new SparkConf(false), krbConf, null)
+ val testImpl = new PrivilegedExceptionAction[Unit] {
+ override def run(): Unit = {
+ assert(UserGroupInformation.isSecurityEnabled())
+ val creds = new Credentials()
+ manager.obtainDelegationTokens(creds)
+ assert(creds.numberOfTokens() === 0)
+ assert(creds.numberOfSecretKeys() === 0)
+ }
+ }
+
+ val realUser = UserGroupInformation.createUserForTesting("realUser", Array.empty)
+ realUser.doAs(testImpl)
+
+ val proxyUser = UserGroupInformation.createProxyUserForTesting("proxyUser", realUser,
+ Array.empty)
+ proxyUser.doAs(testImpl)
+ } finally {
+ if (kdc != null) {
+ kdc.stop()
+ }
+ UserGroupInformation.reset()
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProviderSuite.scala
index 1f19884bc24d3..44f38e7043dcd 100644
--- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProviderSuite.scala
@@ -22,14 +22,15 @@ import org.apache.hadoop.fs.Path
import org.scalatest.Matchers
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.internal.config.STAGING_DIR
+import org.apache.spark.internal.config.{STAGING_DIR, SUBMIT_DEPLOY_MODE}
class HadoopFSDelegationTokenProviderSuite extends SparkFunSuite with Matchers {
test("hadoopFSsToAccess should return defaultFS even if not configured") {
val sparkConf = new SparkConf()
val defaultFS = "hdfs://localhost:8020"
val statingDir = "hdfs://localhost:8021"
- sparkConf.set("spark.master", "yarn-client")
+ sparkConf.setMaster("yarn")
+ sparkConf.set(SUBMIT_DEPLOY_MODE, "client")
sparkConf.set(STAGING_DIR, statingDir)
val hadoopConf = new Configuration()
hadoopConf.set("fs.defaultFS", defaultFS)
diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
index 64d99a59b9192..c34263dd17128 100644
--- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.executor
-import java.io.File
import java.net.URL
import java.nio.ByteBuffer
import java.util.Properties
@@ -30,7 +29,7 @@ import org.json4s.JsonAST.{JArray, JObject}
import org.json4s.JsonDSL._
import org.mockito.Mockito.when
import org.scalatest.concurrent.Eventually.{eventually, timeout}
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.TestUtils._
@@ -136,7 +135,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
// not enough GPUs on the executor
withTempDir { tmpDir =>
val gpuArgs = ResourceAllocation(EXECUTOR_GPU_ID, Seq("0"))
- val ja = Extraction.decompose(Seq(gpuArgs))
+ val ja = Extraction.decompose(Seq(gpuArgs))
val f1 = createTempJsonFile(tmpDir, "resources", ja)
var error = intercept[IllegalArgumentException] {
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index ac7e4b51ebc2b..621151a39eea6 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -35,7 +35,7 @@ import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.PrivateMethodTester
import org.scalatest.concurrent.Eventually
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@@ -56,7 +56,7 @@ import org.apache.spark.util.{LongAccumulator, UninterruptibleThread}
class ExecutorSuite extends SparkFunSuite
with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
- override def afterEach() {
+ override def afterEach(): Unit = {
// Unset any latches after each test; each test that needs them initializes new ones.
ExecutorSuiteHelper.latches = null
super.afterEach()
@@ -528,7 +528,8 @@ class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
throw new FetchFailedException(
bmAddress = BlockManagerId("1", "hostA", 1234),
shuffleId = 0,
- mapId = 0,
+ mapId = 0L,
+ mapIndex = 0,
reduceId = 0,
message = "fake fetch failure"
)
diff --git a/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala b/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala
index 9ed1497db5e1d..9836697e1647c 100644
--- a/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala
@@ -22,9 +22,9 @@ import org.apache.spark.SparkFunSuite
class ProcfsMetricsGetterSuite extends SparkFunSuite {
- val p = new ProcfsMetricsGetter(getTestResourcePath("ProcfsMetrics"))
test("testGetProcessInfo") {
+ val p = new ProcfsMetricsGetter(getTestResourcePath("ProcfsMetrics"))
var r = ProcfsMetrics(0, 0, 0, 0, 0, 0)
r = p.addProcfsMetricsFromOneProcess(r, 26109)
assert(r.jvmVmemTotal == 4769947648L)
diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala
index 576ca1613f75e..9a21ea6dafcac 100644
--- a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala
+++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala
@@ -25,7 +25,6 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
/**
* Tests the correctness of
@@ -35,13 +34,13 @@ import org.apache.spark.util.Utils
class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
private var sc: SparkContext = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
val conf = new SparkConf()
sc = new SparkContext("local", "test", conf)
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
sc.stop()
} finally {
diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
index 47552916adb22..fab7aea6c47aa 100644
--- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
@@ -40,7 +40,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl
private var sc: SparkContext = _
private var factory: CompressionCodecFactory = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
// Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which
// can cause Filesystem.get(Configuration) to return a cached instance created with a different
// configuration than the one passed to get() (see HADOOP-8490 for more details). This caused
@@ -59,7 +59,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl
factory = new CompressionCodecFactory(sc.hadoopConfiguration)
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
sc.stop()
} finally {
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala
index a6b0654204f34..551c0f1a73241 100644
--- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala
@@ -23,7 +23,7 @@ import scala.util.Random
import org.mockito.Mockito.when
import org.scalatest.BeforeAndAfterEach
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
import org.apache.spark.internal.config
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index 7b40e3e58216d..4b27396e6ae05 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
class CompressionCodecSuite extends SparkFunSuite {
val conf = new SparkConf(false)
- def testCodec(codec: CompressionCodec) {
+ def testCodec(codec: CompressionCodec): Unit = {
// Write 1000 integers to the output stream, compressed.
val outputStream = new ByteArrayOutputStream()
val out = codec.compressedOutputStream(outputStream)
diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
index c26945fa5fa31..60f67699f81be 100644
--- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
+++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
@@ -17,60 +17,110 @@
package org.apache.spark.memory
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable
+
import org.apache.spark.SparkConf
import org.apache.spark.storage.BlockId
class TestMemoryManager(conf: SparkConf)
extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) {
+ @GuardedBy("this")
+ private var consequentOOM = 0
+ @GuardedBy("this")
+ private var available = Long.MaxValue
+ @GuardedBy("this")
+ private val memoryForTask = mutable.HashMap[Long, Long]().withDefaultValue(0L)
+
override private[memory] def acquireExecutionMemory(
numBytes: Long,
taskAttemptId: Long,
- memoryMode: MemoryMode): Long = {
- if (consequentOOM > 0) {
- consequentOOM -= 1
- 0
- } else if (available >= numBytes) {
- available -= numBytes
- numBytes
- } else {
- val grant = available
- available = 0
- grant
+ memoryMode: MemoryMode): Long = synchronized {
+ require(numBytes >= 0)
+ val acquired = {
+ if (consequentOOM > 0) {
+ consequentOOM -= 1
+ 0
+ } else if (available >= numBytes) {
+ available -= numBytes
+ numBytes
+ } else {
+ val grant = available
+ available = 0
+ grant
+ }
}
+ memoryForTask(taskAttemptId) = memoryForTask.getOrElse(taskAttemptId, 0L) + acquired
+ acquired
+ }
+
+ override private[memory] def releaseExecutionMemory(
+ numBytes: Long,
+ taskAttemptId: Long,
+ memoryMode: MemoryMode): Unit = synchronized {
+ require(numBytes >= 0)
+ available += numBytes
+ val existingMemoryUsage = memoryForTask.getOrElse(taskAttemptId, 0L)
+ val newMemoryUsage = existingMemoryUsage - numBytes
+ require(
+ newMemoryUsage >= 0,
+ s"Attempting to free $numBytes of memory for task attempt $taskAttemptId, but it only " +
+ s"allocated $existingMemoryUsage bytes of memory")
+ memoryForTask(taskAttemptId) = newMemoryUsage
+ }
+
+ override private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = {
+ memoryForTask.remove(taskAttemptId).getOrElse(0L)
+ }
+
+ override private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = {
+ memoryForTask.getOrElse(taskAttemptId, 0L)
}
+
override def acquireStorageMemory(
blockId: BlockId,
numBytes: Long,
- memoryMode: MemoryMode): Boolean = true
+ memoryMode: MemoryMode): Boolean = {
+ require(numBytes >= 0)
+ true
+ }
+
override def acquireUnrollMemory(
blockId: BlockId,
numBytes: Long,
- memoryMode: MemoryMode): Boolean = true
- override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = {}
- override private[memory] def releaseExecutionMemory(
- numBytes: Long,
- taskAttemptId: Long,
- memoryMode: MemoryMode): Unit = {
- available += numBytes
+ memoryMode: MemoryMode): Boolean = {
+ require(numBytes >= 0)
+ true
}
+
+ override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = {
+ require(numBytes >= 0)
+ }
+
override def maxOnHeapStorageMemory: Long = Long.MaxValue
override def maxOffHeapStorageMemory: Long = 0L
- private var consequentOOM = 0
- private var available = Long.MaxValue
-
+ /**
+ * Causes the next call to [[acquireExecutionMemory()]] to fail to allocate
+ * memory (returning `0`), simulating low-on-memory / out-of-memory conditions.
+ */
def markExecutionAsOutOfMemoryOnce(): Unit = {
markconsequentOOM(1)
}
- def markconsequentOOM(n : Int) : Unit = {
+ /**
+ * Causes the next `n` calls to [[acquireExecutionMemory()]] to fail to allocate
+ * memory (returning `0`), simulating low-on-memory / out-of-memory conditions.
+ */
+ def markconsequentOOM(n: Int): Unit = synchronized {
consequentOOM += n
}
- def limit(avail: Long): Unit = {
+ def limit(avail: Long): Unit = synchronized {
+ require(avail >= 0)
available = avail
}
-
}
diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManagerSuite.scala
new file mode 100644
index 0000000000000..043f341074b88
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManagerSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+
+/**
+ * Tests of [[TestMemoryManager]] itself.
+ */
+class TestMemoryManagerSuite extends SparkFunSuite {
+ test("tracks allocated execution memory by task") {
+ val testMemoryManager = new TestMemoryManager(new SparkConf())
+
+ assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 0)
+ assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 0)
+
+ testMemoryManager.acquireExecutionMemory(10, 0, MemoryMode.ON_HEAP)
+ testMemoryManager.acquireExecutionMemory(5, 1, MemoryMode.ON_HEAP)
+ testMemoryManager.acquireExecutionMemory(5, 0, MemoryMode.ON_HEAP)
+ assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 15)
+ assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 5)
+
+ testMemoryManager.releaseExecutionMemory(10, 0, MemoryMode.ON_HEAP)
+ assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 5)
+
+ testMemoryManager.releaseAllExecutionMemoryForTask(0)
+ testMemoryManager.releaseAllExecutionMemoryForTask(1)
+ assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 0)
+ assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 0)
+ }
+
+ test("markconsequentOOM") {
+ val testMemoryManager = new TestMemoryManager(new SparkConf())
+ assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 1)
+ testMemoryManager.markconsequentOOM(2)
+ assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 0)
+ assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 0)
+ assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 1)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index c7bd0c905d027..dbcec647a3dbc 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -166,7 +166,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
var shuffleRead = 0L
var shuffleWritten = 0L
sc.addSparkListener(new SparkListener() {
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val metrics = taskEnd.taskMetrics
inputRead += metrics.inputMetrics.recordsRead
outputWritten += metrics.outputMetrics.recordsWritten
@@ -182,7 +182,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
.reduceByKey(_ + _)
.saveAsTextFile(tmpFile.toURI.toString)
- sc.listenerBus.waitUntilEmpty(500)
+ sc.listenerBus.waitUntilEmpty()
assert(inputRead == numRecords)
assert(outputWritten == numBuckets)
@@ -243,17 +243,17 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
val taskMetrics = new ArrayBuffer[Long]()
// Avoid receiving earlier taskEnd events
- sc.listenerBus.waitUntilEmpty(500)
+ sc.listenerBus.waitUntilEmpty()
sc.addSparkListener(new SparkListener() {
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
taskMetrics += collector(taskEnd)
}
})
job
- sc.listenerBus.waitUntilEmpty(500)
+ sc.listenerBus.waitUntilEmpty()
taskMetrics.sum
}
@@ -284,7 +284,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
val taskBytesWritten = new ArrayBuffer[Long]()
sc.addSparkListener(new SparkListener() {
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
taskBytesWritten += taskEnd.taskMetrics.outputMetrics.bytesWritten
}
})
@@ -293,7 +293,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
try {
rdd.saveAsTextFile(outPath.toString)
- sc.listenerBus.waitUntilEmpty(500)
+ sc.listenerBus.waitUntilEmpty()
assert(taskBytesWritten.length == 2)
val outFiles = fs.listStatus(outPath).filter(_.getPath.getName != "_SUCCESS")
taskBytesWritten.zip(outFiles).foreach { case (bytes, fileStatus) =>
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
index 544d52d48b385..e05fad19567ae 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -29,7 +29,7 @@ import scala.util.{Failure, Success, Try}
import com.google.common.io.CharStreams
import org.mockito.Mockito._
import org.scalatest.Matchers
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
index 5d67d3358a9ca..edddf88a28f85 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
@@ -41,7 +41,7 @@ class NettyBlockTransferServiceSuite
private var service0: NettyBlockTransferService = _
private var service1: NettyBlockTransferService = _
- override def afterEach() {
+ override def afterEach(): Unit = {
try {
if (service0 != null) {
service0.close()
diff --git a/core/src/test/scala/org/apache/spark/network/netty/SparkTransportConfSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/SparkTransportConfSuite.scala
index d7265b6c24fe7..55cd1a4bfe7dd 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/SparkTransportConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/SparkTransportConfSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.network.netty
-import org.scalatest.Matchers
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.network.util.NettyUtils
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index a7eb0eca72e56..faef953e9fb90 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -37,12 +37,12 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
// Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
implicit val defaultSignaler: Signaler = ThreadSignaler
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
sc = new SparkContext("local[2]", "test")
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
LocalSparkContext.stop(sc)
sc = null
@@ -86,7 +86,7 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
}
test("takeAsync") {
- def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
+ def testTake(rdd: RDD[Int], input: Seq[Int], num: Int): Unit = {
val expected = input.take(num)
val saw = rdd.takeAsync(num).get()
assert(saw == expected, "incorrect result for rdd with %d partitions (expected %s, saw %s)"
diff --git a/core/src/test/scala/org/apache/spark/rdd/CoalescedRDDBenchmark.scala b/core/src/test/scala/org/apache/spark/rdd/CoalescedRDDBenchmark.scala
index 42b30707f2624..617ca5a1a8bc4 100644
--- a/core/src/test/scala/org/apache/spark/rdd/CoalescedRDDBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/CoalescedRDDBenchmark.scala
@@ -67,7 +67,8 @@ object CoalescedRDDBenchmark extends BenchmarkBase {
benchmark.run()
}
- private def performCoalesce(blocks: immutable.Seq[(Int, Seq[String])], numPartitions: Int) {
+ private def performCoalesce(blocks: immutable.Seq[(Int, Seq[String])],
+ numPartitions: Int): Unit = {
sc.makeRDD(blocks).coalesce(numPartitions).partitions
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 1564435a0bbae..01fe170073a10 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -200,7 +200,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
assert(sums.partitioner === Some(p))
// count the dependencies to make sure there is only 1 ShuffledRDD
val deps = new HashSet[RDD[_]]()
- def visit(r: RDD[_]) {
+ def visit(r: RDD[_]): Unit = {
for (dep <- r.dependencies) {
deps += dep.rdd
visit(dep.rdd)
diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
index 424d9f825c465..10f4bbcf7f48b 100644
--- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
@@ -22,7 +22,7 @@ import scala.collection.immutable.NumericRange
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
-import org.scalatest.prop.Checkers
+import org.scalatestplus.scalacheck.Checkers
import org.apache.spark.SparkFunSuite
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index cb0de1c6beb6b..da2ccbfae181f 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -25,7 +25,7 @@ class MockSampler extends RandomSampler[Long, Long] {
private var s: Long = _
- override def setSeed(seed: Long) {
+ override def setSeed(seed: Long): Unit = {
s = seed
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 69739a2e58481..860cf4d7ed9b2 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -21,16 +21,18 @@ import java.io.File
import scala.collection.JavaConverters._
import scala.collection.Map
+import scala.concurrent.duration._
import scala.io.Codec
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat}
+import org.scalatest.concurrent.Eventually
import org.apache.spark._
import org.apache.spark.util.Utils
-class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
+class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
val envCommand = if (Utils.isWindows) {
"cmd.exe /C set"
} else {
@@ -100,11 +102,16 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
assert(result.collect().length === 0)
- // collect stderr writer threads
- val stderrWriterThread = Thread.getAllStackTraces.keySet().asScala
- .find { _.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX) }
-
- assert(stderrWriterThread.isEmpty)
+ // SPARK-29104 PipedRDD will invoke `stdinWriterThread.interrupt()` at task completion,
+ // and `obj.wait` will get InterruptedException. However, there exists a possibility
+ // that the thread termination gets delayed because the thread starts from `obj.wait()`
+ // with that exception. To prevent test flakiness, we need to use `eventually`.
+ eventually(timeout(10.seconds), interval(1.second)) {
+ // collect stdin writer threads
+ val stdinWriterThread = Thread.getAllStackTraces.keySet().asScala
+ .find { _.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX) }
+ assert(stdinWriterThread.isEmpty)
+ }
}
test("advanced pipe") {
@@ -230,7 +237,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
testExportInputFile("mapreduce_map_input_file")
}
- def testExportInputFile(varName: String) {
+ def testExportInputFile(varName: String): Unit = {
assume(TestUtils.testCommandAvailable(envCommand))
val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable],
classOf[Text], 2) {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 60e63bfd68625..859c25ff03819 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -366,7 +366,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
assert(math.abs(partitions1(1).length - 500) < initialPartitions)
assert(repartitioned1.collect() === input)
- def testSplitPartitions(input: Seq[Int], initialPartitions: Int, finalPartitions: Int) {
+ def testSplitPartitions(input: Seq[Int], initialPartitions: Int, finalPartitions: Int): Unit = {
val data = sc.parallelize(input, initialPartitions)
val repartitioned = data.repartition(finalPartitions)
assert(repartitioned.partitions.size === finalPartitions)
@@ -1099,7 +1099,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually {
override def index: Int = 0
})
override def getDependencies: Seq[Dependency[_]] = mutableDependencies
- def addDependency(dep: Dependency[_]) {
+ def addDependency(dep: Dependency[_]): Unit = {
mutableDependencies += dep
}
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 5bdf71be35b3b..5929fbf85a1f4 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -409,7 +409,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
(0 until 10) foreach { _ =>
new Thread {
- override def run() {
+ override def run(): Unit = {
(0 until 100) foreach { _ =>
endpointRef.send("Hello")
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
index 59b4b706bbcdd..378d433cf44f8 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.ExecutionException
import scala.concurrent.duration._
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.network.client.TransportClient
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
index 8d5f04ac7651a..fc8ac38479932 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
@@ -26,13 +26,18 @@ import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
- test("global sync by barrier() call") {
+ def initLocalClusterSparkContext(): Unit = {
val conf = new SparkConf()
// Init local cluster here so each barrier task runs in a separate process, thus the `barrier()`
// call is actually useful.
.setMaster("local-cluster[4, 1, 1024]")
.setAppName("test-cluster")
+ .set(TEST_NO_STAGE_RETRY, true)
sc = new SparkContext(conf)
+ }
+
+ test("global sync by barrier() call") {
+ initLocalClusterSparkContext()
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
@@ -48,10 +53,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
}
test("support multiple barrier() call within a single task") {
- val conf = new SparkConf()
- .setMaster("local-cluster[4, 1, 1024]")
- .setAppName("test-cluster")
- sc = new SparkContext(conf)
+ initLocalClusterSparkContext()
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
@@ -77,12 +79,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
}
test("throw exception on barrier() call timeout") {
- val conf = new SparkConf()
- .set("spark.barrier.sync.timeout", "1")
- .set(TEST_NO_STAGE_RETRY, true)
- .setMaster("local-cluster[4, 1, 1024]")
- .setAppName("test-cluster")
- sc = new SparkContext(conf)
+ initLocalClusterSparkContext()
+ sc.conf.set("spark.barrier.sync.timeout", "1")
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
@@ -102,12 +100,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
}
test("throw exception if barrier() call doesn't happen on every task") {
- val conf = new SparkConf()
- .set("spark.barrier.sync.timeout", "1")
- .set(TEST_NO_STAGE_RETRY, true)
- .setMaster("local-cluster[4, 1, 1024]")
- .setAppName("test-cluster")
- sc = new SparkContext(conf)
+ initLocalClusterSparkContext()
+ sc.conf.set("spark.barrier.sync.timeout", "1")
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
@@ -125,12 +119,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
}
test("throw exception if the number of barrier() calls are not the same on every task") {
- val conf = new SparkConf()
- .set("spark.barrier.sync.timeout", "1")
- .set(TEST_NO_STAGE_RETRY, true)
- .setMaster("local-cluster[4, 1, 1024]")
- .setAppName("test-cluster")
- sc = new SparkContext(conf)
+ initLocalClusterSparkContext()
+ sc.conf.set("spark.barrier.sync.timeout", "1")
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
@@ -156,10 +146,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
assert(error.contains("within 1 second(s)"))
}
-
- def testBarrierTaskKilled(sc: SparkContext, interruptOnCancel: Boolean): Unit = {
- sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString)
-
+ def testBarrierTaskKilled(interruptOnKill: Boolean): Unit = {
withTempDir { dir =>
val killedFlagFile = "barrier.task.killed"
val rdd = sc.makeRDD(Seq(0, 1), 2)
@@ -181,12 +168,15 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
val listener = new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
- new Thread {
- override def run: Unit = {
- Thread.sleep(1000)
- sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = false)
- }
- }.start()
+ val partitionId = taskStart.taskInfo.index
+ if (partitionId == 0) {
+ new Thread {
+ override def run: Unit = {
+ Thread.sleep(1000)
+ sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = interruptOnKill)
+ }
+ }.start()
+ }
}
}
sc.addSparkListener(listener)
@@ -201,15 +191,13 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
}
}
- test("barrier task killed") {
- val conf = new SparkConf()
- .set("spark.barrier.sync.timeout", "1")
- .set(TEST_NO_STAGE_RETRY, true)
- .setMaster("local-cluster[4, 1, 1024]")
- .setAppName("test-cluster")
- sc = new SparkContext(conf)
+ test("barrier task killed, no interrupt") {
+ initLocalClusterSparkContext()
+ testBarrierTaskKilled(interruptOnKill = false)
+ }
- testBarrierTaskKilled(sc, true)
- testBarrierTaskKilled(sc, false)
+ test("barrier task killed, interrupt") {
+ initLocalClusterSparkContext()
+ testBarrierTaskKilled(interruptOnKill = true)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala
index 0fe0e5b78233c..246d4b2f56ec9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala
@@ -16,8 +16,6 @@
*/
package org.apache.spark.scheduler
-import scala.concurrent.duration._
-
import org.apache.spark._
import org.apache.spark.internal.config
import org.apache.spark.internal.config.Tests._
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
index 93a88cc30a20c..a1671a58f0d9b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
@@ -21,7 +21,7 @@ import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{never, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.scalatest.BeforeAndAfterEach
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.internal.config
@@ -437,7 +437,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
}
test("check blacklist configuration invariants") {
- val conf = new SparkConf().setMaster("yarn-cluster")
+ val conf = new SparkConf().setMaster("yarn").set(config.SUBMIT_DEPLOY_MODE, "cluster")
Seq(
(2, 2),
(2, 3)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
index 3edbbeb9c08f1..61522145f8868 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.scheduler
import java.util.Properties
import java.util.concurrent.atomic.AtomicBoolean
-import scala.collection.immutable
import scala.collection.mutable
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -29,7 +28,7 @@ import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.when
import org.mockito.invocation.InvocationOnMock
import org.scalatest.concurrent.Eventually
-import org.scalatest.mockito.MockitoSugar._
+import org.scalatestplus.mockito.MockitoSugar._
import org.apache.spark._
import org.apache.spark.internal.config._
diff --git a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
index 1be2e2a067115..46e5e6f97b1f1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
@@ -111,7 +111,7 @@ class CustomShuffledRDD[K, V, C](
.asInstanceOf[Iterator[(K, C)]]
}
- override def clearDependencies() {
+ override def clearDependencies(): Unit = {
super.clearDependencies()
dependency = null
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index cd854c379b08a..bd0a35af206af 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -151,7 +151,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
taskSets += taskSet
}
- override def cancelTasks(stageId: Int, interruptThread: Boolean) {
+ override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {
cancelledStages += stageId
}
override def killTaskAttempt(
@@ -172,34 +172,66 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
override def applicationAttemptId(): Option[String] = None
}
- /** Length of time to wait while draining listener events. */
- val WAIT_TIMEOUT_MILLIS = 10000
-
- val submittedStageInfos = new HashSet[StageInfo]
- val successfulStages = new HashSet[Int]
- val failedStages = new ArrayBuffer[Int]
- val stageByOrderOfExecution = new ArrayBuffer[Int]
- val endedTasks = new HashSet[Long]
- val sparkListener = new SparkListener() {
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- submittedStageInfos += stageSubmitted.stageInfo
+ /**
+ * Listener which records some information to verify in UTs. Getter-kind methods in this class
+ * ensure the value is returned only after there is no event left to process, and that the
+ * returned value is immutable, preventing odd results caused by race conditions.
+ */
+ class EventInfoRecordingListener extends SparkListener {
+ private val _submittedStageInfos = new HashSet[StageInfo]
+ private val _successfulStages = new HashSet[Int]
+ private val _failedStages = new ArrayBuffer[Int]
+ private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+ private val _endedTasks = new HashSet[Long]
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+ _submittedStageInfos += stageSubmitted.stageInfo
}
- override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
val stageInfo = stageCompleted.stageInfo
- stageByOrderOfExecution += stageInfo.stageId
+ _stageByOrderOfExecution += stageInfo.stageId
if (stageInfo.failureReason.isEmpty) {
- successfulStages += stageInfo.stageId
+ _successfulStages += stageInfo.stageId
} else {
- failedStages += stageInfo.stageId
+ _failedStages += stageInfo.stageId
}
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
- endedTasks += taskEnd.taskInfo.taskId
+ _endedTasks += taskEnd.taskInfo.taskId
+ }
+
+ def submittedStageInfos: Set[StageInfo] = {
+ waitForListeners()
+ _submittedStageInfos.toSet
+ }
+
+ def successfulStages: Set[Int] = {
+ waitForListeners()
+ _successfulStages.toSet
}
+
+ def failedStages: List[Int] = {
+ waitForListeners()
+ _failedStages.toList
+ }
+
+ def stageByOrderOfExecution: List[Int] = {
+ waitForListeners()
+ _stageByOrderOfExecution.toList
+ }
+
+ def endedTasks: Set[Long] = {
+ waitForListeners()
+ _endedTasks.toSet
+ }
+
+ private def waitForListeners(): Unit = sc.listenerBus.waitUntilEmpty()
}
+ var sparkListener: EventInfoRecordingListener = null
+
var mapOutputTracker: MapOutputTrackerMaster = null
var broadcastManager: BroadcastManager = null
var securityMgr: SecurityManager = null
@@ -220,7 +252,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
getOrElse(Seq())
}.toIndexedSeq
}
- override def removeExecutor(execId: String) {
+ override def removeExecutor(execId: String): Unit = {
// don't need to propagate to the driver, which we don't have
}
}
@@ -248,10 +280,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
private def init(testConf: SparkConf): Unit = {
sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf)
- submittedStageInfos.clear()
- successfulStages.clear()
- failedStages.clear()
- endedTasks.clear()
+ sparkListener = new EventInfoRecordingListener
failure = null
sc.addSparkListener(sparkListener)
taskSets.clear()
@@ -287,7 +316,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
}
- override def afterAll() {
+ override def afterAll(): Unit = {
super.afterAll()
}
@@ -304,7 +333,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
* After processing the event, submit waiting stages as is done on most iterations of the
* DAGScheduler event loop.
*/
- private def runEvent(event: DAGSchedulerEvent) {
+ private def runEvent(event: DAGSchedulerEvent): Unit = {
dagEventProcessLoopTester.post(event)
}
@@ -317,7 +346,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
it.next.asInstanceOf[Tuple2[_, _]]._1
/** Send the given CompletionEvent messages for the tasks in the TaskSet. */
- private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+ private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]): Unit = {
assert(taskSet.tasks.size >= results.size)
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
@@ -329,7 +358,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
private def completeWithAccumulator(
accumId: Long,
taskSet: TaskSet,
- results: Seq[(TaskEndReason, Any)]) {
+ results: Seq[(TaskEndReason, Any)]): Unit = {
assert(taskSet.tasks.size >= results.size)
for ((result, i) <- results.zipWithIndex) {
if (i < taskSet.tasks.size) {
@@ -364,19 +393,18 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
/** Sends TaskSetFailed to the scheduler. */
- private def failed(taskSet: TaskSet, message: String) {
+ private def failed(taskSet: TaskSet, message: String): Unit = {
runEvent(TaskSetFailed(taskSet, message, None))
}
/** Sends JobCancelled to the DAG scheduler. */
- private def cancel(jobId: Int) {
+ private def cancel(jobId: Int): Unit = {
runEvent(JobCancelled(jobId, None))
}
test("[SPARK-3353] parent stage should have lower stage id") {
- stageByOrderOfExecution.clear()
sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ val stageByOrderOfExecution = sparkListener.stageByOrderOfExecution
assert(stageByOrderOfExecution.length === 2)
assert(stageByOrderOfExecution(0) < stageByOrderOfExecution(1))
}
@@ -456,18 +484,22 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// map stage1 completes successfully, with one task on each executor
complete(taskSets(0), Seq(
(Success,
- MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))),
+ MapStatus(
+ BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), mapTaskId = 5)),
(Success,
- MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))),
- (Success, makeMapStatus("hostB", 1))
+ MapStatus(
+ BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), mapTaskId = 6)),
+ (Success, makeMapStatus("hostB", 1, mapTaskId = 7))
))
// map stage2 completes successfully, with one task on each executor
complete(taskSets(1), Seq(
(Success,
- MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))),
+ MapStatus(
+ BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), mapTaskId = 8)),
(Success,
- MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))),
- (Success, makeMapStatus("hostB", 1))
+ MapStatus(
+ BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), mapTaskId = 9)),
+ (Success, makeMapStatus("hostB", 1, mapTaskId = 10))
))
// make sure our test setup is correct
val initialMapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses
@@ -475,16 +507,19 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(initialMapStatus1.count(_ != null) === 3)
assert(initialMapStatus1.map{_.location.executorId}.toSet ===
Set("exec-hostA1", "exec-hostA2", "exec-hostB"))
+ assert(initialMapStatus1.map{_.mapTaskId}.toSet === Set(5, 6, 7))
val initialMapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses
// val initialMapStatus1 = mapOutputTracker.mapStatuses.get(0).get
assert(initialMapStatus2.count(_ != null) === 3)
assert(initialMapStatus2.map{_.location.executorId}.toSet ===
Set("exec-hostA1", "exec-hostA2", "exec-hostB"))
+ assert(initialMapStatus2.map{_.mapTaskId}.toSet === Set(8, 9, 10))
// reduce stage fails with a fetch failure from one host
complete(taskSets(2), Seq(
- (FetchFailed(BlockManagerId("exec-hostA2", "hostA", 12345), firstShuffleId, 0, 0, "ignored"),
+ (FetchFailed(BlockManagerId("exec-hostA2", "hostA", 12345),
+ firstShuffleId, 0L, 0, 0, "ignored"),
null)
))
@@ -619,9 +654,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
submit(unserializableRdd, Array(0))
assert(failure.getMessage.startsWith(
"Job aborted due to stage failure: Task not serializable:"))
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(failedStages.contains(0))
- assert(failedStages.size === 1)
+ assert(sparkListener.failedStages === Seq(0))
assertDataStructuresEmpty()
}
@@ -629,9 +662,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
submit(new MyRDD(sc, 1, Nil), Array(0))
failed(taskSets(0), "some failure")
assert(failure.getMessage === "Job aborted due to stage failure: some failure")
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(failedStages.contains(0))
- assert(failedStages.size === 1)
+ assert(sparkListener.failedStages === Seq(0))
assertDataStructuresEmpty()
}
@@ -640,9 +671,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val jobId = submit(rdd, Array(0))
cancel(jobId)
assert(failure.getMessage === s"Job $jobId cancelled ")
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(failedStages.contains(0))
- assert(failedStages.size === 1)
+ assert(sparkListener.failedStages === Seq(0))
assertDataStructuresEmpty()
}
@@ -657,7 +686,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
override def submitTasks(taskSet: TaskSet): Unit = {
taskSets += taskSet
}
- override def cancelTasks(stageId: Int, interruptThread: Boolean) {
+ override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {
throw new UnsupportedOperationException
}
override def killTaskAttempt(
@@ -700,9 +729,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(results === Map(0 -> 42))
assertDataStructuresEmpty()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(failedStages.isEmpty)
- assert(successfulStages.contains(0))
+ assert(sparkListener.failedStages.isEmpty)
+ assert(sparkListener.successfulStages.contains(0))
}
test("run trivial shuffle") {
@@ -733,7 +761,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// the 2nd ResultTask failed
complete(taskSets(1), Seq(
(Success, 42),
- (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)))
// this will get called
// blockManagerMaster.removeExecutor("exec-hostA")
// ask the scheduler to try it again
@@ -815,18 +843,18 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val testRdd = new MyRDD(sc, 0, Nil)
val waiter = scheduler.submitJob(testRdd, func, Seq.empty, CallSite.empty,
resultHandler, properties)
- sc.listenerBus.waitUntilEmpty(1000L)
+ sc.listenerBus.waitUntilEmpty()
assert(assertionError.get() === null)
}
// Helper function to validate state when creating tests for task failures
- private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet) {
+ private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet): Unit = {
assert(stageAttempt.stageId === stageId)
assert(stageAttempt.stageAttemptId == attempt)
}
// Helper functions to extract commonly used code in Fetch Failure test cases
- private def setupStageAbortTest(sc: SparkContext) {
+ private def setupStageAbortTest(sc: SparkContext): Unit = {
sc.listenerBus.addToSharedQueue(new EndListener())
ended = false
jobResult = null
@@ -880,7 +908,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val stageAttempt = taskSets.last
checkStageId(stageId, attemptIdx, stageAttempt)
complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) =>
- (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0, idx, "ignored"), null)
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0L, 0, idx, "ignored"), null)
}.toSeq)
}
@@ -933,7 +961,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
completeNextResultStageWithSuccess(1, 1)
// Confirm job finished successfully
- sc.listenerBus.waitUntilEmpty(1000)
+ sc.listenerBus.waitUntilEmpty()
assert(ended)
assert(results === (0 until parts).map { idx => idx -> 42 }.toMap)
assertDataStructuresEmpty()
@@ -970,7 +998,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
} else {
// Stage should have been aborted and removed from running stages
assertDataStructuresEmpty()
- sc.listenerBus.waitUntilEmpty(1000)
+ sc.listenerBus.waitUntilEmpty()
assert(ended)
jobResult match {
case JobFailed(reason) =>
@@ -1092,7 +1120,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
completeNextResultStageWithSuccess(2, 1)
assertDataStructuresEmpty()
- sc.listenerBus.waitUntilEmpty(1000)
+ sc.listenerBus.waitUntilEmpty()
assert(ended)
assert(results === Map(0 -> 42))
}
@@ -1113,19 +1141,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(makeCompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"),
null))
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(failedStages.contains(1))
+ assert(sparkListener.failedStages.contains(1))
// The second ResultTask fails, with a fetch failure for the output from the second mapper.
runEvent(makeCompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1L, 1, 1, "ignored"),
null))
// The SparkListener should not receive redundant failure events.
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(failedStages.size == 1)
+ assert(sparkListener.failedStages.size === 1)
}
test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") {
@@ -1142,7 +1168,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(makeCompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"),
null))
assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1)))
@@ -1153,7 +1179,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// Complete the result stage.
completeNextResultStageWithSuccess(1, 1)
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
assertDataStructuresEmpty()
}
@@ -1172,7 +1198,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
taskSets(0).tasks(1),
TaskKilled("test"),
null))
- assert(failedStages === Seq(0))
+ assert(sparkListener.failedStages === Seq(0))
assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1)))
scheduler.resubmitFailedStages()
@@ -1182,7 +1208,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// Complete the result stage.
completeNextResultStageWithSuccess(1, 0)
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
assertDataStructuresEmpty()
}
@@ -1208,7 +1234,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
null))
// Assert the stage has been cancelled.
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
assert(failure.getMessage.startsWith("Job aborted due to stage failure: Could not recover " +
"from a failed barrier ResultStage."))
}
@@ -1226,11 +1252,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
- submittedStageInfos.count(_.stageId == mapStageId)
+ sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}
// The map stage should have been submitted.
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)
complete(taskSets(0), Seq(
@@ -1245,14 +1270,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(makeCompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"),
null))
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(failedStages.contains(1))
+ assert(sparkListener.failedStages.contains(1))
// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)
@@ -1260,7 +1283,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// The second ResultTask fails, with a fetch failure for the output from the second mapper.
runEvent(makeCompletionEvent(
taskSets(1).tasks(1),
- FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1L, 1, 1, "ignored"),
null))
// Another ResubmitFailedStages event should not result in another attempt for the map
@@ -1269,7 +1292,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// shouldn't effect anything -- our calling it just makes *SURE* it gets called between the
// desired event and our check.
runEvent(ResubmitFailedStages)
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)
}
@@ -1287,14 +1309,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
submit(reduceRdd, Array(0, 1))
def countSubmittedReduceStageAttempts(): Int = {
- submittedStageInfos.count(_.stageId == 1)
+ sparkListener.submittedStageInfos.count(_.stageId == 1)
}
def countSubmittedMapStageAttempts(): Int = {
- submittedStageInfos.count(_.stageId == 0)
+ sparkListener.submittedStageInfos.count(_.stageId == 0)
}
// The map stage should have been submitted.
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)
// Complete the map stage.
@@ -1303,13 +1324,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
(Success, makeMapStatus("hostB", 2))))
// The reduce stage should have been submitted.
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedReduceStageAttempts() === 1)
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(makeCompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"),
null))
// Trigger resubmission of the failed map stage and finish the re-started map task.
@@ -1318,14 +1338,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// Because the map stage finished, another attempt for the reduce stage should have been
// submitted, resulting in 2 total attempts for each the map and the reduce stage.
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)
assert(countSubmittedReduceStageAttempts() === 2)
// A late FetchFailed arrives from the second task in the original reduce stage.
runEvent(makeCompletionEvent(
taskSets(1).tasks(1),
- FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1L, 1, 1, "ignored"),
null))
// Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because
@@ -1348,10 +1367,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(makeCompletionEvent(
taskSets(0).tasks(1), Success, 42,
Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
// verify stage exists
assert(scheduler.stageIdToStage.contains(0))
- assert(endedTasks.size == 2)
+ assert(sparkListener.endedTasks.size === 2)
// finish other 2 tasks
runEvent(makeCompletionEvent(
@@ -1360,8 +1378,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), Success, 42,
Seq.empty, Array.empty, createFakeTaskInfoWithId(3)))
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(endedTasks.size == 4)
+ assert(sparkListener.endedTasks.size === 4)
// verify the stage is done
assert(!scheduler.stageIdToStage.contains(0))
@@ -1371,15 +1388,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), Success, 42,
Seq.empty, Array.empty, createFakeTaskInfoWithId(5)))
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(endedTasks.size == 5)
+ assert(sparkListener.endedTasks.size === 5)
// make sure non successful tasks also send out event
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), UnknownReason, 42,
Seq.empty, Array.empty, createFakeTaskInfoWithId(6)))
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(endedTasks.size == 6)
+ assert(sparkListener.endedTasks.size === 6)
}
test("ignore late map task completions") {
@@ -1452,8 +1467,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// Listener bus should get told about the map stage failing, but not the reduce stage
// (since the reduce stage hasn't been started yet).
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(failedStages.toSet === Set(0))
+ assert(sparkListener.failedStages.toSet === Set(0))
assertDataStructuresEmpty()
}
@@ -1525,7 +1539,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(ExecutorLost("exec-hostA", ExecutorKilled))
runEvent(makeCompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"),
+ FetchFailed(null, firstShuffleId, 2L, 2, 0, "Fetch failed"),
null))
// so we resubmit stage 0, which completes happily
@@ -1681,7 +1695,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// listener for all jobs, and here we want to capture the failure for each job separately.
class FailureRecordingJobListener() extends JobListener {
var failureMessage: String = _
- override def taskSucceeded(index: Int, result: Any) {}
+ override def taskSucceeded(index: Int, result: Any): Unit = {}
override def jobFailed(exception: Exception): Unit = { failureMessage = exception.getMessage }
}
val listener1 = new FailureRecordingJobListener()
@@ -1696,9 +1710,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(cancelledStages.toSet === Set(0, 2))
// Make sure the listeners got told about both failed stages.
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
- assert(successfulStages.isEmpty)
- assert(failedStages.toSet === Set(0, 2))
+ assert(sparkListener.successfulStages.isEmpty)
+ assert(sparkListener.failedStages.toSet === Set(0, 2))
assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
@@ -1785,7 +1798,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// lets say there is a fetch failure in this task set, which makes us go back and
// run stage 0, attempt 1
complete(taskSets(1), Seq(
- (FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0, 0, "ignored"), null)))
+ (FetchFailed(makeBlockManagerId("hostA"),
+ shuffleDep1.shuffleId, 0L, 0, 0, "ignored"), null)))
scheduler.resubmitFailedStages()
// stage 0, attempt 1 should have the properties of job2
@@ -1866,7 +1880,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
(Success, makeMapStatus("hostC", 1))))
// fail the third stage because hostA went down
complete(taskSets(2), Seq(
- (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
+ (FetchFailed(makeBlockManagerId("hostA"),
+ shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null)))
// TODO assert this:
// blockManagerMaster.removeExecutor("exec-hostA")
// have DAGScheduler try again
@@ -1897,7 +1912,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
(Success, makeMapStatus("hostB", 1))))
// pretend stage 2 failed because hostA went down
complete(taskSets(2), Seq(
- (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
+ (FetchFailed(makeBlockManagerId("hostA"),
+ shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null)))
// TODO assert this:
// blockManagerMaster.removeExecutor("exec-hostA")
// DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
@@ -2258,7 +2274,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
submit(reduceRdd, Array(0, 1))
complete(taskSets(1), Seq(
(Success, 42),
- (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)))
// Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch
// from, then TaskSet 3 will run the reduce stage
scheduler.resubmitFailedStages()
@@ -2317,7 +2333,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(taskSets(1).stageId === 1)
complete(taskSets(1), Seq(
(Success, makeMapStatus("hostA", rdd2.partitions.length)),
- (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0L, 0, 0, "ignored"), null)))
scheduler.resubmitFailedStages()
assert(listener2.results.size === 0) // Second stage listener should not have a result yet
@@ -2343,7 +2359,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(taskSets(4).stageId === 2)
complete(taskSets(4), Seq(
(Success, 52),
- (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0, 0, "ignored"), null)))
+ (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0L, 0, 0, "ignored"), null)))
scheduler.resubmitFailedStages()
// TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2
@@ -2381,7 +2397,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(taskSets(1).stageId === 1)
complete(taskSets(1), Seq(
(Success, makeMapStatus("hostC", rdd2.partitions.length)),
- (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0L, 0, 0, "ignored"), null)))
scheduler.resubmitFailedStages()
// Stage1 listener should not have a result yet
assert(listener2.results.size === 0)
@@ -2516,7 +2532,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
rdd1.map {
case (x, _) if (x == 1) =>
throw new FetchFailedException(
- BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test")
+ BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0L, 0, 0, "test")
case (x, _) => x
}.count()
}
@@ -2529,7 +2545,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
rdd1.map {
case (x, _) if (x == 1) && FailThisAttempt._fail.getAndSet(false) =>
throw new FetchFailedException(
- BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test")
+ BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0L, 0, 0, "test")
}
}
@@ -2583,7 +2599,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(taskSets(1).stageId === 1 && taskSets(1).stageAttemptId === 0)
runEvent(makeCompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleIdA, 0, 0,
+ FetchFailed(makeBlockManagerId("hostA"), shuffleIdA, 0L, 0, 0,
"Fetch failure of task: stageId=1, stageAttempt=0, partitionId=0"),
result = null))
@@ -2659,7 +2675,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
sc.parallelize(1 to tasks, tasks).foreach { _ =>
accum.add(1L)
}
- sc.listenerBus.waitUntilEmpty(1000)
+ sc.listenerBus.waitUntilEmpty()
assert(foundCount.get() === tasks)
}
}
@@ -2672,11 +2688,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
- submittedStageInfos.count(_.stageId == mapStageId)
+ sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}
// The map stage should have been submitted.
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)
// The first map task fails with TaskKilled.
@@ -2684,7 +2699,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
taskSets(0).tasks(0),
TaskKilled("test"),
null))
- assert(failedStages === Seq(0))
+ assert(sparkListener.failedStages === Seq(0))
// The second map task fails with TaskKilled.
runEvent(makeCompletionEvent(
@@ -2694,7 +2709,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)
@@ -2708,11 +2722,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
- submittedStageInfos.count(_.stageId == mapStageId)
+ sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}
// The map stage should have been submitted.
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)
// The first map task fails with TaskKilled.
@@ -2720,11 +2733,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
taskSets(0).tasks(0),
TaskKilled("test"),
null))
- assert(failedStages === Seq(0))
+ assert(sparkListener.failedStages === Seq(0))
// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)
@@ -2737,11 +2749,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// The second map task failure doesn't trigger stage retry.
runEvent(ResubmitFailedStages)
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)
}
- test("SPARK-23207: retry all the succeeding stages when the map stage is indeterminate") {
+ private def constructIndeterminateStageFetchFailed(): (Int, Int) = {
val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
@@ -2769,14 +2780,152 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// The first task of the final stage failed with fetch failure
runEvent(makeCompletionEvent(
taskSets(2).tasks(0),
- FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"),
+ FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0L, 0, 0, "ignored"),
+ null))
+ (shuffleId1, shuffleId2)
+ }
+
+ test("SPARK-25341: abort stage while using old fetch protocol") {
+ // Reset the test context to use the old fetch protocol.
+ afterEach()
+ val conf = new SparkConf()
+ conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true")
+ init(conf)
+ // Construct the scenario where an indeterminate stage hits a fetch failure.
+ constructIndeterminateStageFetchFailed()
+ // The job should fail because Spark can't roll back the shuffle map stage while
+ // using the old protocol.
+ assert(failure != null && failure.getMessage.contains(
+ "Spark can only do this while using the new shuffle block fetching protocol"))
+ }
+
+ test("SPARK-25341: retry all the succeeding stages when the map stage is indeterminate") {
+ val (shuffleId1, shuffleId2) = constructIndeterminateStageFetchFailed()
+
+ // Check status for all failedStages
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.map(_.id) == Seq(1, 2))
+ // Shuffle blocks on "hostC" are lost, so the first task of `shuffleMapRdd2` needs to retry.
+ assert(failedStages.collect {
+ case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage
+ }.head.findMissingPartitions() == Seq(0))
+ // The result stage is still waiting for its 2 tasks to complete
+ assert(failedStages.collect {
+ case stage: ResultStage => stage
+ }.head.findMissingPartitions() == Seq(0, 1))
+
+ scheduler.resubmitFailedStages()
+
+ // The first task of `shuffleMapRdd2` failed with a fetch failure.
+ runEvent(makeCompletionEvent(
+ taskSets(3).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"),
+ null))
+
+ val newFailedStages = scheduler.failedStages.toSeq
+ assert(newFailedStages.map(_.id) == Seq(0, 1))
+
+ scheduler.resubmitFailedStages()
+
+ // The first shuffle map stage was resubmitted and reran all of its tasks.
+ assert(taskSets(4).stageId == 0)
+ assert(taskSets(4).stageAttemptId == 1)
+ assert(taskSets(4).tasks.length == 2)
+
+ // Finish all stages.
+ complete(taskSets(4), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+
+ complete(taskSets(5), Seq(
+ (Success, makeMapStatus("hostC", 2)),
+ (Success, makeMapStatus("hostD", 2))))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+
+ complete(taskSets(6), Seq((Success, 11), (Success, 12)))
+
+ // The job ended successfully.
+ assert(results === Map(0 -> 11, 1 -> 12))
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-25341: continuous indeterminate stage roll back") {
+ // shuffleMapRdd1/2/3 are all indeterminate.
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
+ val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
+ val shuffleId1 = shuffleDep1.shuffleId
+
+ val shuffleMapRdd2 = new MyRDD(
+ sc, 2, List(shuffleDep1), tracker = mapOutputTracker, indeterminate = true)
+ val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2))
+ val shuffleId2 = shuffleDep2.shuffleId
+
+ val shuffleMapRdd3 = new MyRDD(
+ sc, 2, List(shuffleDep2), tracker = mapOutputTracker, indeterminate = true)
+ val shuffleDep3 = new ShuffleDependency(shuffleMapRdd3, new HashPartitioner(2))
+ val shuffleId3 = shuffleDep3.shuffleId
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = mapOutputTracker)
+
+ submit(finalRdd, Array(0, 1), properties = new Properties())
+
+ // Finish the first 2 shuffle map stages.
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+
+ complete(taskSets(1), Seq(
+ (Success, makeMapStatus("hostB", 2)),
+ (Success, makeMapStatus("hostD", 2))))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+
+ // Executor lost on hostB, so both stage 0 and stage 1 should be rerun.
+ runEvent(makeCompletionEvent(
+ taskSets(2).tasks(0),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleId2, 0L, 0, 0, "ignored"),
null))
+ mapOutputTracker.removeOutputsOnHost("hostB")
- // The second shuffle map stage need to rerun, the job will abort for the indeterminate
- // stage rerun.
- // TODO: After we support re-generate shuffle file(SPARK-25341), this test will be extended.
- assert(failure != null && failure.getMessage
- .contains("Spark cannot rollback the ShuffleMapStage 1"))
+ assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2))
+ scheduler.resubmitFailedStages()
+
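+ // Verifies that the retry task set at `taskSetIndex` is attempt 1 of the expected stage with
+ // both tasks resubmitted, then completes it and confirms the shuffle has no missing partitions.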
+ def checkAndCompleteRetryStage(
+ taskSetIndex: Int,
+ stageId: Int,
+ shuffleId: Int): Unit = {
+ assert(taskSets(taskSetIndex).stageId == stageId)
+ assert(taskSets(taskSetIndex).stageAttemptId == 1)
+ assert(taskSets(taskSetIndex).tasks.length == 2)
+ complete(taskSets(taskSetIndex), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+ }
+
+ // Check that all indeterminate stages are rolled back.
+ checkAndCompleteRetryStage(3, 0, shuffleId1)
+ checkAndCompleteRetryStage(4, 1, shuffleId2)
+ checkAndCompleteRetryStage(5, 2, shuffleId3)
+
+ // The result stage succeeded, so the whole job ended.
+ complete(taskSets(6), Seq((Success, 11), (Success, 12)))
+ assert(results === Map(0 -> 11, 1 -> 12))
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-29042: Sampled RDD with unordered input should be indeterminate") {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = false)
+
+ val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
+ val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+
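+ // Pulling data through a shuffle makes shuffleMapRdd2's output order UNORDERED, and sampling
+ // with replacement over an unordered input should make the result INDETERMINATE.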
+ assert(shuffleMapRdd2.outputDeterministicLevel == DeterministicLevel.UNORDERED)
+
+ val sampledRdd = shuffleMapRdd2.sample(true, 0.3, 1000L)
+ assert(sampledRdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE)
}
private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = {
@@ -2797,7 +2946,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// Fail the second task with FetchFailed.
runEvent(makeCompletionEvent(
taskSets.last.tasks(1),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"),
null))
// The job should fail because Spark can't rollback the result stage.
@@ -2840,7 +2989,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// Fail the second task with FetchFailed.
runEvent(makeCompletionEvent(
taskSets.last.tasks(1),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"),
null))
assert(failure == null, "job should not fail")
@@ -2887,33 +3036,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(latch.await(10, TimeUnit.SECONDS))
}
- test("SPARK-28699: abort stage if parent stage is indeterminate stage") {
- val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true)
-
- val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
- val shuffleId = shuffleDep.shuffleId
- val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
-
- submit(finalRdd, Array(0, 1))
-
- // Finish the first shuffle map stage.
- complete(taskSets(0), Seq(
- (Success, makeMapStatus("hostA", 2)),
- (Success, makeMapStatus("hostB", 2))))
- assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
-
- runEvent(makeCompletionEvent(
- taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
- null))
-
- // Shuffle blocks of "hostA" is lost, so first task of the `shuffleMapRdd` needs to retry.
- // The result stage is still waiting for its 2 tasks to complete.
- // Because of shuffleMapRdd is indeterminate, this job will be abort.
- assert(failure != null && failure.getMessage
- .contains("Spark cannot rollback the ShuffleMapStage 0"))
- }
-
test("Completions in zombie tasksets update status of non-zombie taskset") {
val parts = 4
val shuffleMapRdd = new MyRDD(sc, parts, Nil)
@@ -2930,7 +3052,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// The second task of the shuffle map stage failed with FetchFailed.
runEvent(makeCompletionEvent(
taskSets(0).tasks(1),
- FetchFailed(makeBlockManagerId("hostB"), shuffleDep.shuffleId, 0, 0, "ignored"),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleDep.shuffleId, 0L, 0, 0, "ignored"),
null))
scheduler.resubmitFailedStages()
@@ -2969,7 +3091,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
*/
- private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) {
+ private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]): Unit = {
assert(hosts.size === taskSet.tasks.size)
for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) {
assert(taskLocs.map(_.host).toSet === expectedLocs.toSet)
@@ -3020,8 +3142,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
object DAGSchedulerSuite {
- def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus =
- MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes))
+ def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1): MapStatus =
+ MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId)
def makeBlockManagerId(host: String): BlockManagerId =
BlockManagerId("exec-" + host, host, 12345)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index a83ca594ee908..ae55d1915fa4a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -178,7 +178,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
*/
private def testEventLogging(
compressionCodec: Option[String] = None,
- extraConf: Map[String, String] = Map()) {
+ extraConf: Map[String, String] = Map()): Unit = {
val conf = getLoggingConf(testDirPath, compressionCodec)
extraConf.foreach { case (k, v) => conf.set(k, v) }
val logName = compressionCodec.map("test-" + _).getOrElse("test")
@@ -218,7 +218,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
* Test end-to-end event logging functionality in an application.
* This runs a simple Spark job and asserts that the expected events are logged when expected.
*/
- private def testApplicationEventLogging(compressionCodec: Option[String] = None) {
+ private def testApplicationEventLogging(compressionCodec: Option[String] = None): Unit = {
// Set defaultFS to something that would cause an exception, to make sure we don't run
// into SPARK-6688.
val conf = getLoggingConf(testDirPath, compressionCodec)
@@ -284,7 +284,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
* from SparkListenerTaskEnd events for tasks belonging to the stage are
* logged in a StageExecutorMetrics event for each executor at stage completion.
*/
- private def testStageExecutorMetricsEventLogging() {
+ private def testStageExecutorMetricsEventLogging(): Unit = {
val conf = getLoggingConf(testDirPath, None)
val logName = "stageExecutorMetrics-test"
val eventLogger = new EventLoggingListener(logName, None, testDirPath.toUri(), conf)
@@ -621,19 +621,19 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
var jobEnded = false
var appEnded = false
- override def onJobStart(jobStart: SparkListenerJobStart) {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
jobStarted = true
}
- override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
jobEnded = true
}
- override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) {
+ override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
appEnded = true
}
- def assertAllCallbacksInvoked() {
+ def assertAllCallbacksInvoked(): Unit = {
assert(jobStarted, "JobStart callback not invoked!")
assert(jobEnded, "JobEnd callback not invoked!")
assert(appEnded, "ApplicationEnd callback not invoked!")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index 73e88c4a0fda6..4e71ec1ea7b37 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -67,9 +67,9 @@ private class DummyExternalClusterManager extends ExternalClusterManager {
private class DummySchedulerBackend extends SchedulerBackend {
var initialized = false
- def start() {}
- def stop() {}
- def reviveOffers() {}
+ def start(): Unit = {}
+ def stop(): Unit = {}
+ def reviveOffers(): Unit = {}
def defaultParallelism(): Int = 1
def maxNumConcurrentTasks(): Int = 0
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index b29d32f7b35c5..8cb6268f85d36 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -42,22 +42,30 @@ object FakeTask {
* locations for each task (given as varargs) if this sequence is not empty.
*/
def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
- createTaskSet(numTasks, stageAttemptId = 0, prefLocs: _*)
+ createTaskSet(numTasks, stageId = 0, stageAttemptId = 0, priority = 0, prefLocs: _*)
}
- def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
- createTaskSet(numTasks, stageId = 0, stageAttemptId, prefLocs: _*)
+ def createTaskSet(
+ numTasks: Int,
+ stageId: Int,
+ stageAttemptId: Int,
+ prefLocs: Seq[TaskLocation]*): TaskSet = {
+ createTaskSet(numTasks, stageId, stageAttemptId, priority = 0, prefLocs: _*)
}
- def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*):
- TaskSet = {
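+ /**
+ * Variant of createTaskSet that also takes an explicit priority; the other overloads default
+ * the priority to 0.
+ */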
+ def createTaskSet(
+ numTasks: Int,
+ stageId: Int,
+ stageAttemptId: Int,
+ priority: Int,
+ prefLocs: Seq[TaskLocation]*): TaskSet = {
if (prefLocs.size != 0 && prefLocs.size != numTasks) {
throw new IllegalArgumentException("Wrong number of task locations")
}
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil)
}
- new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
+ new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null)
}
def createShuffleMapTaskSet(
@@ -65,6 +73,15 @@ object FakeTask {
stageId: Int,
stageAttemptId: Int,
prefLocs: Seq[TaskLocation]*): TaskSet = {
+ createShuffleMapTaskSet(numTasks, stageId, stageAttemptId, priority = 0, prefLocs: _*)
+ }
+
+ def createShuffleMapTaskSet(
+ numTasks: Int,
+ stageId: Int,
+ stageAttemptId: Int,
+ priority: Int,
+ prefLocs: Seq[TaskLocation]*): TaskSet = {
if (prefLocs.size != 0 && prefLocs.size != numTasks) {
throw new IllegalArgumentException("Wrong number of task locations")
}
@@ -74,17 +91,18 @@ object FakeTask {
}, prefLocs(i), new Properties,
SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array())
}
- new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
+ new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null)
}
def createBarrierTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
- createBarrierTaskSet(numTasks, stageId = 0, stageAttempId = 0, prefLocs: _*)
+ createBarrierTaskSet(numTasks, stageId = 0, stageAttemptId = 0, priority = 0, prefLocs: _*)
}
def createBarrierTaskSet(
numTasks: Int,
stageId: Int,
- stageAttempId: Int,
+ stageAttemptId: Int,
+ priority: Int,
prefLocs: Seq[TaskLocation]*): TaskSet = {
if (prefLocs.size != 0 && prefLocs.size != numTasks) {
throw new IllegalArgumentException("Wrong number of task locations")
@@ -92,6 +110,6 @@ object FakeTask {
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true)
}
- new TaskSet(tasks, stageId, stageAttempId, priority = 0, null)
+ new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index c1e7fb9a1db16..700d9ebd76c0c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -61,7 +61,7 @@ class MapStatusSuite extends SparkFunSuite {
stddev <- Seq(0.0, 0.01, 0.5, 1.0)
) {
val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean)
- val status = MapStatus(BlockManagerId("a", "b", 10), sizes)
+ val status = MapStatus(BlockManagerId("a", "b", 10), sizes, -1)
val status1 = compressAndDecompressMapStatus(status)
for (i <- 0 until numSizes) {
if (sizes(i) != 0) {
@@ -75,7 +75,7 @@ class MapStatusSuite extends SparkFunSuite {
test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) {
val sizes = Array.fill[Long](2001)(150L)
- val status = MapStatus(null, sizes)
+ val status = MapStatus(null, sizes, -1)
assert(status.isInstanceOf[HighlyCompressedMapStatus])
assert(status.getSizeForBlock(10) === 150L)
assert(status.getSizeForBlock(50) === 150L)
@@ -87,10 +87,12 @@ class MapStatusSuite extends SparkFunSuite {
val sizes = Array.tabulate[Long](3000) { i => i.toLong }
val avg = sizes.sum / sizes.count(_ != 0)
val loc = BlockManagerId("a", "b", 10)
- val status = MapStatus(loc, sizes)
+ val mapTaskAttemptId = 5
+ val status = MapStatus(loc, sizes, mapTaskAttemptId)
val status1 = compressAndDecompressMapStatus(status)
assert(status1.isInstanceOf[HighlyCompressedMapStatus])
assert(status1.location == loc)
+ assert(status1.mapTaskId == mapTaskAttemptId)
for (i <- 0 until 3000) {
val estimate = status1.getSizeForBlock(i)
if (sizes(i) > 0) {
@@ -109,7 +111,7 @@ class MapStatusSuite extends SparkFunSuite {
val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold)
val avg = smallBlockSizes.sum / smallBlockSizes.length
val loc = BlockManagerId("a", "b", 10)
- val status = MapStatus(loc, sizes)
+ val status = MapStatus(loc, sizes, 5)
val status1 = compressAndDecompressMapStatus(status)
assert(status1.isInstanceOf[HighlyCompressedMapStatus])
assert(status1.location == loc)
@@ -165,7 +167,7 @@ class MapStatusSuite extends SparkFunSuite {
SparkEnv.set(env)
// Value of element in sizes is equal to the corresponding index.
val sizes = (0L to 2000L).toArray
- val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes)
+ val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 5)
val arrayStream = new ByteArrayOutputStream(102400)
val objectOutputStream = new ObjectOutputStream(arrayStream)
assert(status1.isInstanceOf[HighlyCompressedMapStatus])
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala
index 848f702935536..7d063c3b3ac53 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala
@@ -22,7 +22,6 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.{Seconds, Span}
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext}
-import org.apache.spark.util.Utils
/**
* Integration tests for the OutputCommitCoordinator.
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index d6964063c118e..05d9ec4861de9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -254,7 +254,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
.reduceByKey { case (_, _) =>
val ctx = TaskContext.get()
if (ctx.stageAttemptNumber() == 0) {
- throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1,
+ throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1L, 1, 1,
new Exception("Failure for test."))
} else {
ctx.stageId()
@@ -288,7 +288,7 @@ private case class OutputCommitFunctions(tempDirPath: String) {
// Mock output committer that simulates a failed commit (after commit is authorized)
private def failingOutputCommitter = new FakeOutputCommitter {
- override def commitTask(taskAttemptContext: TaskAttemptContext) {
+ override def commitTask(taskAttemptContext: TaskAttemptContext): Unit = {
throw new RuntimeException
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index d65b5cbfc094e..55e7f5333c676 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io._
import java.net.URI
+import java.nio.charset.StandardCharsets
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
@@ -52,10 +53,11 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
test("Simple replay") {
val logFilePath = getFilePath(testDir, "events.txt")
val fstream = fileSystem.create(logFilePath)
+ val fwriter = new OutputStreamWriter(fstream, StandardCharsets.UTF_8)
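+ // Write the event log through an explicit UTF-8 writer instead of the platform default charset.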
val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None,
125L, "Mickey", None)
val applicationEnd = SparkListenerApplicationEnd(1000L)
- Utils.tryWithResource(new PrintWriter(fstream)) { writer =>
+ Utils.tryWithResource(new PrintWriter(fwriter)) { writer =>
// scalastyle:off println
writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart))))
writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd))))
@@ -87,8 +89,9 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
test("Replay compressed inprogress log file succeeding on partial read") {
val buffered = new ByteArrayOutputStream
val codec = new LZ4CompressionCodec(new SparkConf())
- val compstream = codec.compressedOutputStream(buffered)
- Utils.tryWithResource(new PrintWriter(compstream)) { writer =>
+ val compstream = codec.compressedContinuousOutputStream(buffered)
+ val cwriter = new OutputStreamWriter(compstream, StandardCharsets.UTF_8)
+ Utils.tryWithResource(new PrintWriter(cwriter)) { writer =>
val applicationStart = SparkListenerApplicationStart("AppStarts", None,
125L, "Mickey", None)
@@ -134,10 +137,11 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
test("Replay incompatible event log") {
val logFilePath = getFilePath(testDir, "incompatible.txt")
val fstream = fileSystem.create(logFilePath)
+ val fwriter = new OutputStreamWriter(fstream, StandardCharsets.UTF_8)
val applicationStart = SparkListenerApplicationStart("Incompatible App", None,
125L, "UserUsingIncompatibleVersion", None)
val applicationEnd = SparkListenerApplicationEnd(1000L)
- Utils.tryWithResource(new PrintWriter(fstream)) { writer =>
+ Utils.tryWithResource(new PrintWriter(fwriter)) { writer =>
// scalastyle:off println
writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart))))
writer.println("""{"Event":"UnrecognizedEventOnlyForTest","Timestamp":1477593059313}""")
@@ -184,7 +188,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
* event to the corresponding event replayed from the event logs. This test makes the
* assumption that the event logging behavior is correct (tested in a separate suite).
*/
- private def testApplicationReplay(codecName: Option[String] = None) {
+ private def testApplicationReplay(codecName: Option[String] = None): Unit = {
val logDir = new File(testDir.getAbsolutePath, "test-replay")
// Here, it creates `Path` from the URI instead of the absolute path for the explicit file
// scheme so that the string representation of this `Path` has leading file scheme correctly.
@@ -242,7 +246,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
private[scheduler] val loggedEvents = new ArrayBuffer[JValue]
- override def onEvent(event: SparkListenerEvent) {
+ override def onEvent(event: SparkListenerEvent): Unit = {
val eventJson = JsonProtocol.sparkEventToJson(event)
loggedEvents += eventJson
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
index 96706536fe53c..4f737c9499ad6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
@@ -621,7 +621,7 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor
backend.taskSuccess(taskDescription, DAGSchedulerSuite.makeMapStatus("hostA", 10))
case (1, 0, 0) =>
val fetchFailed = FetchFailed(
- DAGSchedulerSuite.makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored")
+ DAGSchedulerSuite.makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored")
backend.taskFailed(taskDescription, fetchFailed)
case (1, _, partition) =>
backend.taskSuccess(taskDescription, 42 + partition)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 8903e1054f53d..f73ebd6a5b42d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -38,9 +38,6 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
import LiveListenerBus._
- /** Length of time to wait while draining listener events. */
- val WAIT_TIMEOUT_MILLIS = 10000
-
val jobCompletionTime = 1421191296660L
private val mockSparkContext: SparkContext = Mockito.mock(classOf[SparkContext])
@@ -65,7 +62,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
sc.listenerBus.addToSharedQueue(listener)
sc.listenerBus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
sc.stop()
assert(listener.sparkExSeen)
@@ -97,7 +94,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
// Starting listener bus should flush all buffered events
bus.start(mockSparkContext, mockMetricsSystem)
Mockito.verify(mockMetricsSystem).registerSource(bus.metrics)
- bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ bus.waitUntilEmpty()
assert(counter.count === 5)
assert(sharedQueueSize(bus) === 0)
assert(eventProcessingTimeCount(bus) === 5)
@@ -159,7 +156,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
assert(!drained)
new Thread("ListenerBusStopper") {
- override def run() {
+ override def run(): Unit = {
stopperStarted.release()
// stop() will block until notify() is called below
bus.stop()
@@ -223,7 +220,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
rdd2.setName("Target RDD")
rdd2.count()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
listener.stageInfos.size should be {1}
val (stageInfo, taskInfoMetrics) = listener.stageInfos.head
@@ -248,7 +245,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
rdd3.setName("Trois")
rdd1.count()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
listener.stageInfos.size should be {1}
val stageInfo1 = listener.stageInfos.keys.find(_.stageId == 0).get
stageInfo1.rddInfos.size should be {1} // ParallelCollectionRDD
@@ -257,7 +254,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
listener.stageInfos.clear()
rdd2.count()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
listener.stageInfos.size should be {1}
val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get
stageInfo2.rddInfos.size should be {3}
@@ -266,7 +263,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
listener.stageInfos.clear()
rdd3.count()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
listener.stageInfos.size should be {2} // Shuffle map stage + result stage
val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get
stageInfo3.rddInfos.size should be {1} // ShuffledRDD
@@ -282,7 +279,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
val rdd2 = rdd1.map(_.toString)
sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1))
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
listener.stageInfos.size should be {1}
val (stageInfo, _) = listener.stageInfos.head
@@ -310,7 +307,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
val numSlices = 16
val d = sc.parallelize(0 to 10000, numSlices).map(w)
d.count()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
listener.stageInfos.size should be (1)
val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1")
@@ -321,7 +318,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
d4.setName("A Cogroup")
d4.collectAsMap()
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
listener.stageInfos.size should be (4)
listener.stageInfos.foreach { case (stageInfo, taskInfoMetrics) =>
/**
@@ -372,7 +369,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
.reduce { case (x, y) => x }
assert(result === 1.to(maxRpcMessageSize).toArray)
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
val TASK_INDEX = 0
assert(listener.startedTasks.contains(TASK_INDEX))
assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
@@ -388,7 +385,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x }
assert(result === 2)
- sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ sc.listenerBus.waitUntilEmpty()
val TASK_INDEX = 0
assert(listener.startedTasks.contains(TASK_INDEX))
assert(listener.startedGettingResultTasks.isEmpty)
@@ -443,7 +440,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
// Post events to all listeners, and wait until the queue is drained
(1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
- bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ bus.waitUntilEmpty()
// The exception should be caught, and the event should be propagated to other listeners
assert(jobCounter1.count === 5)
@@ -513,7 +510,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
// after we post one event, both interrupting listeners should get removed, and the
// event log queue should be removed
bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
- bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ bus.waitUntilEmpty()
assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE))
assert(bus.findListenersByClass[BasicJobCounter]().size === 2)
assert(bus.findListenersByClass[InterruptingListener]().size === 0)
@@ -522,7 +519,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
// posting more events should be fine, they'll just get processed from the OK queue.
(0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
- bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ bus.waitUntilEmpty()
assert(counter1.count === 6)
assert(counter2.count === 6)
@@ -563,7 +560,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
/**
* Assert that the given list of numbers has an average that is greater than zero.
*/
- private def checkNonZeroAvg(m: Iterable[Long], msg: String) {
+ private def checkNonZeroAvg(m: Iterable[Long], msg: String): Unit = {
assert(m.sum / m.size.toDouble > 0.0, msg)
}
@@ -574,7 +571,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
val stageInfos = mutable.Map[StageInfo, Seq[(TaskInfo, TaskMetrics)]]()
var taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()
- override def onTaskEnd(task: SparkListenerTaskEnd) {
+ override def onTaskEnd(task: SparkListenerTaskEnd): Unit = {
val info = task.taskInfo
val metrics = task.taskMetrics
if (info != null && metrics != null) {
@@ -582,7 +579,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
}
}
- override def onStageCompleted(stage: SparkListenerStageCompleted) {
+ override def onStageCompleted(stage: SparkListenerStageCompleted): Unit = {
stageInfos(stage.stageInfo) = taskInfoMetrics
taskInfoMetrics = mutable.Buffer.empty
}
@@ -606,7 +603,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
notify()
}
- override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit = {
startedGettingResultTasks += taskGettingResult.taskInfo.index
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala
index a6576e0d1c520..c84735c9665a7 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala
@@ -57,7 +57,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext
private class SaveExecutorInfo extends SparkListener {
val addedExecutorInfo = mutable.Map[String, ExecutorInfo]()
- override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
+ override def onExecutorAdded(executor: SparkListenerExecutorAdded): Unit = {
addedExecutorInfo(executor.executorId) = executor.executorInfo
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index c16b552d20891..394a2a9fbf7cb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -176,7 +176,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
if (stageAttemptNumber < 2) {
// Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception
// will only trigger task resubmission in the same stage.
- throw new FetchFailedException(null, 0, 0, 0, "Fake")
+ throw new FetchFailedException(null, 0, 0L, 0, 0, "Fake")
}
Seq(stageAttemptNumber).iterator
}.collect()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index ae464352da440..8439be955c738 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._
import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
import org.apache.spark.TestUtils.JavaSourceFromString
import org.apache.spark.internal.config.Network.RPC_MESSAGE_MAX_SIZE
import org.apache.spark.storage.TaskResultBlockId
@@ -52,7 +53,7 @@ private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: Task
@volatile var removeBlockSuccessfully = false
override def enqueueSuccessfulTask(
- taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer): Unit = {
if (!removedResult) {
// Only remove the result once, since we'd like to test the case where the task eventually
// succeeds.
@@ -78,6 +79,16 @@ private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: Task
}
}
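+// Stubs out handleFailedTask so the spy used in the test below can record the failure calls
+// without running TaskSchedulerImpl's real failure handling.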
+private class DummyTaskSchedulerImpl(sc: SparkContext)
+ extends TaskSchedulerImpl(sc, 1, true) {
+ override def handleFailedTask(
+ taskSetManager: TaskSetManager,
+ tid: Long,
+ taskState: TaskState,
+ reason: TaskFailedReason): Unit = {
+ // do nothing
+ }
+}
/**
* A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors
@@ -130,6 +141,31 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
"Expect result to be removed from the block manager.")
}
+ test("handling total size of results larger than maxResultSize") {
+ sc = new SparkContext("local", "test", conf)
+ val scheduler = new DummyTaskSchedulerImpl(sc)
+ val spyScheduler = spy(scheduler)
+ val resultGetter = new TaskResultGetter(sc.env, spyScheduler)
+ scheduler.taskResultGetter = resultGetter
+ val myTsm = new TaskSetManager(spyScheduler, FakeTask.createTaskSet(2), 1) {
+ // canFetchMoreResults always returns false to simulate results exceeding maxResultSize
+ override def canFetchMoreResults(size: Long): Boolean = false
+ }
+ val indirectTaskResult = IndirectTaskResult(TaskResultBlockId(0), 0)
+ val directTaskResult = new DirectTaskResult(ByteBuffer.allocate(0), Nil, Array())
+ val ser = sc.env.closureSerializer.newInstance()
+ val serializedIndirect = ser.serialize(indirectTaskResult)
+ val serializedDirect = ser.serialize(directTaskResult)
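+ // With canFetchMoreResults always false, both the direct and the indirect result should be
+ // dropped and each task reported as killed for exceeding maxResultSize.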
+ resultGetter.enqueueSuccessfulTask(myTsm, 0, serializedDirect)
+ resultGetter.enqueueSuccessfulTask(myTsm, 1, serializedIndirect)
+ eventually(timeout(1.second)) {
+ verify(spyScheduler, times(1)).handleFailedTask(
+ myTsm, 0, TaskState.KILLED, TaskKilled("Tasks result size has exceeded maxResultSize"))
+ verify(spyScheduler, times(1)).handleFailedTask(
+ myTsm, 1, TaskState.KILLED, TaskKilled("Tasks result size has exceeded maxResultSize"))
+ }
+ }
+
test("task retried if result missing from block manager") {
// Set the maximum number of task failures to > 0, so that the task set isn't aborted
// after the result is missing.
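
[Reviewer note, not part of the patch] The new "total size of results larger than maxResultSize" test stubs canFetchMoreResults to always return false; the user-visible behaviour it models is governed by spark.driver.maxResultSize. A hedged usage sketch (the values and job are illustrative, not taken from this patch):

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("max-result-size-demo")
      .set("spark.driver.maxResultSize", "1m")
    val sc = new SparkContext(conf)
    try {
      // Each task returns ~1 MiB, so collecting several results should exceed
      // the 1m cap and fail the job with a result-size error.
      sc.parallelize(1 to 8, 8).map(_ => new Array[Byte](1024 * 1024)).collect()
    } finally {
      sc.stop()
    }
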
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index cac6285e58417..e7ecf847ff4f4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -26,7 +26,7 @@ import org.mockito.ArgumentMatchers.{any, anyInt, anyString, eq => meq}
import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when}
import org.scalatest.BeforeAndAfterEach
import org.scalatest.concurrent.Eventually
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.internal.Logging
@@ -36,9 +36,9 @@ import org.apache.spark.resource.TestResourceIDs._
import org.apache.spark.util.ManualClock
class FakeSchedulerBackend extends SchedulerBackend {
- def start() {}
- def stop() {}
- def reviveOffers() {}
+ def start(): Unit = {}
+ def stop(): Unit = {}
+ def reviveOffers(): Unit = {}
def defaultParallelism(): Int = 1
def maxNumConcurrentTasks(): Int = 0
}
@@ -228,19 +228,19 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
taskScheduler.taskSetManagerForAttempt(taskset.stageId, taskset.stageAttemptId).get.isZombie
}
- val attempt1 = FakeTask.createTaskSet(1, 0)
+ val attempt1 = FakeTask.createTaskSet(1, stageId = 0, stageAttemptId = 0)
taskScheduler.submitTasks(attempt1)
// The first submitted taskset is active
assert(!isTasksetZombie(attempt1))
- val attempt2 = FakeTask.createTaskSet(1, 1)
+ val attempt2 = FakeTask.createTaskSet(1, stageId = 0, stageAttemptId = 1)
taskScheduler.submitTasks(attempt2)
// The first submitted taskset is zombie now
assert(isTasksetZombie(attempt1))
// The newly submitted taskset is active
assert(!isTasksetZombie(attempt2))
- val attempt3 = FakeTask.createTaskSet(1, 2)
+ val attempt3 = FakeTask.createTaskSet(1, stageId = 0, stageAttemptId = 2)
taskScheduler.submitTasks(attempt3)
// The first submitted taskset remains zombie
assert(isTasksetZombie(attempt1))
@@ -255,7 +255,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
val numFreeCores = 1
val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores))
- val attempt1 = FakeTask.createTaskSet(10)
+ val attempt1 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 0)
// submit attempt 1, offer some resources, some tasks get scheduled
taskScheduler.submitTasks(attempt1)
@@ -271,7 +271,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(0 === taskDescriptions2.length)
// if we schedule another attempt for the same stage, it should get scheduled
- val attempt2 = FakeTask.createTaskSet(10, 1)
+ val attempt2 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 1)
// submit attempt 2, offer some resources, some tasks get scheduled
taskScheduler.submitTasks(attempt2)
@@ -287,7 +287,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
val numFreeCores = 10
val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores))
- val attempt1 = FakeTask.createTaskSet(10)
+ val attempt1 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 0)
// submit attempt 1, offer some resources, some tasks get scheduled
taskScheduler.submitTasks(attempt1)
@@ -303,7 +303,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(0 === taskDescriptions2.length)
// submit attempt 2
- val attempt2 = FakeTask.createTaskSet(10, 1)
+ val attempt2 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 1)
taskScheduler.submitTasks(attempt2)
// attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were
@@ -497,7 +497,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
test("abort stage when all executors are blacklisted and we cannot acquire new executor") {
taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
- val taskSet = FakeTask.createTaskSet(numTasks = 10, stageAttemptId = 0)
+ val taskSet = FakeTask.createTaskSet(numTasks = 10)
taskScheduler.submitTasks(taskSet)
val tsm = stageToMockTaskSetManager(0)
@@ -539,7 +539,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
config.UNSCHEDULABLE_TASKSET_TIMEOUT.key -> "0")
// We have only 1 task remaining with 1 executor
- val taskSet = FakeTask.createTaskSet(numTasks = 1, stageAttemptId = 0)
+ val taskSet = FakeTask.createTaskSet(numTasks = 1)
taskScheduler.submitTasks(taskSet)
val tsm = stageToMockTaskSetManager(0)
@@ -571,7 +571,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
config.UNSCHEDULABLE_TASKSET_TIMEOUT.key -> "10")
// We have only 1 task remaining with 1 executor
- val taskSet = FakeTask.createTaskSet(numTasks = 1, stageAttemptId = 0)
+ val taskSet = FakeTask.createTaskSet(numTasks = 1)
taskScheduler.submitTasks(taskSet)
val tsm = stageToMockTaskSetManager(0)
@@ -910,7 +910,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
test("SPARK-16106 locality levels updated if executor added to existing host") {
val taskScheduler = setupScheduler()
- taskScheduler.submitTasks(FakeTask.createTaskSet(2, 0,
+ taskScheduler.submitTasks(FakeTask.createTaskSet(2, stageId = 0, stageAttemptId = 0,
(0 until 2).map { _ => Seq(TaskLocation("host0", "executor2")) }: _*
))
@@ -948,7 +948,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
test("scheduler checks for executors that can be expired from blacklist") {
taskScheduler = setupScheduler()
- taskScheduler.submitTasks(FakeTask.createTaskSet(1, 0))
+ taskScheduler.submitTasks(FakeTask.createTaskSet(1, stageId = 0, stageAttemptId = 0))
taskScheduler.resourceOffers(IndexedSeq(
new WorkerOffer("executor0", "host0", 1)
)).flatten
@@ -962,8 +962,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
new DAGScheduler(sc, taskScheduler) {
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
- override def executorAdded(execId: String, host: String) {}
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo): Unit = {}
+ override def executorAdded(execId: String, host: String): Unit = {}
}
val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1))
@@ -993,8 +993,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
new DAGScheduler(sc, taskScheduler) {
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
- override def executorAdded(execId: String, host: String) {}
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo): Unit = {}
+ override def executorAdded(execId: String, host: String): Unit = {}
}
val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1))
@@ -1044,8 +1044,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
}
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
new DAGScheduler(sc, taskScheduler) {
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
- override def executorAdded(execId: String, host: String) {}
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo): Unit = {}
+ override def executorAdded(execId: String, host: String): Unit = {}
}
taskScheduler.initialize(new FakeSchedulerBackend)
@@ -1084,8 +1084,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
}
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
new DAGScheduler(sc, taskScheduler) {
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
- override def executorAdded(execId: String, host: String) {}
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo): Unit = {}
+ override def executorAdded(execId: String, host: String): Unit = {}
}
taskScheduler.initialize(new FakeSchedulerBackend)
// make an offer on the preferred host so the scheduler knows it's alive. This is necessary
@@ -1154,6 +1154,29 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
assert(3 === taskDescriptions.length)
}
+ test("SPARK-29263: barrier TaskSet can't schedule when higher prio taskset takes the slots") {
+ val taskCpus = 2
+ val taskScheduler = setupSchedulerWithMaster(
+ s"local[$taskCpus]",
+ config.CPUS_PER_TASK.key -> taskCpus.toString)
+
+ val numFreeCores = 3
+ val workerOffers = IndexedSeq(
+ new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")),
+ new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627")),
+ new WorkerOffer("executor2", "host2", numFreeCores, Some("192.168.0.101:49629")))
+ val barrier = FakeTask.createBarrierTaskSet(3, stageId = 0, stageAttemptId = 0, priority = 1)
+ val highPrio = FakeTask.createTaskSet(1, stageId = 1, stageAttemptId = 0, priority = 0)
+
+ // submit highPrio and barrier taskSet
+ taskScheduler.submitTasks(highPrio)
+ taskScheduler.submitTasks(barrier)
+ val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+ // it schedules the highPrio task first, and then will not have enough slots to schedule
+ // the barrier taskset
+ assert(1 === taskDescriptions.length)
+ }
+
test("cancelTasks shall kill all the running tasks and fail the stage") {
val taskScheduler = setupScheduler()
@@ -1169,7 +1192,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
}
})
- val attempt1 = FakeTask.createTaskSet(10, 0)
+ val attempt1 = FakeTask.createTaskSet(10)
taskScheduler.submitTasks(attempt1)
val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
@@ -1200,7 +1223,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
}
})
- val attempt1 = FakeTask.createTaskSet(10, 0)
+ val attempt1 = FakeTask.createTaskSet(10)
taskScheduler.submitTasks(attempt1)
val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
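
[Reviewer note, not part of the patch] The new SPARK-29263 test encodes the rule that a barrier TaskSet is all-or-nothing: it only launches when the remaining offers can host every one of its tasks at once. A hedged sketch of the same contract from the user side (assumes an existing SparkContext sc):

    // All three barrier tasks must start together; with fewer than three free
    // slots the stage waits rather than launching a subset of the tasks.
    val rdd = sc.parallelize(1 to 3, numSlices = 3)
      .barrier()
      .mapPartitions { iter => iter }
    rdd.count()
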
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala
index b3bc76687ce1b..ed97a4c206ca3 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import org.mockito.ArgumentMatchers.isA
import org.mockito.Mockito.{never, verify}
import org.scalatest.BeforeAndAfterEach
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.config
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index fedfa083e8d8f..441ec6ab6e18b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.util.{AccumulatorV2, ManualClock}
class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
extends DAGScheduler(sc) {
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo): Unit = {
taskScheduler.startedTasks += taskInfo.index
}
@@ -48,13 +48,13 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
result: Any,
accumUpdates: Seq[AccumulatorV2[_, _]],
metricPeaks: Array[Long],
- taskInfo: TaskInfo) {
+ taskInfo: TaskInfo): Unit = {
taskScheduler.endedTasks(taskInfo.index) = reason
}
- override def executorAdded(execId: String, host: String) {}
+ override def executorAdded(execId: String, host: String): Unit = {}
- override def executorLost(execId: String, reason: ExecutorLossReason) {}
+ override def executorLost(execId: String, reason: ExecutorLossReason): Unit = {}
override def taskSetFailed(
taskSet: TaskSet,
@@ -74,13 +74,13 @@ object FakeRackUtil {
var numBatchInvocation = 0
var numSingleHostInvocation = 0
- def cleanUp() {
+ def cleanUp(): Unit = {
hostToRack.clear()
numBatchInvocation = 0
numSingleHostInvocation = 0
}
- def assignHostToRack(host: String, rack: String) {
+ def assignHostToRack(host: String, rack: String): Unit = {
hostToRack(host) = rack
}
@@ -124,7 +124,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
dagScheduler = new FakeDAGScheduler(sc, this)
- def removeExecutor(execId: String) {
+ def removeExecutor(execId: String): Unit = {
executors -= execId
val host = executorIdToHost.get(execId)
assert(host != None)
@@ -149,7 +149,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
hostsByRack.get(rack) != None
}
- def addExecutor(execId: String, host: String) {
+ def addExecutor(execId: String, host: String): Unit = {
executors.put(execId, host)
val executorsOnHost = hostToExecutors.getOrElseUpdate(host, new mutable.HashSet[String])
executorsOnHost += execId
@@ -1262,7 +1262,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
// now fail those tasks
tsmSpy.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED,
- FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored"))
+ FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0L, 0, 0, "ignored"))
tsmSpy.handleFailedTask(taskDescs(1).taskId, TaskState.FAILED,
ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None))
tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED,
@@ -1302,7 +1302,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
// Fail the task with fetch failure
tsm.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED,
- FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored"))
+ FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0L, 0, 0, "ignored"))
assert(blacklistTracker.isNodeBlacklisted("host1"))
}
diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala
index be6b8a6b5b108..213f0ba2ec180 100644
--- a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala
@@ -27,7 +27,7 @@ trait EncryptionFunSuite {
* Runs a test twice, initializing a SparkConf object with encryption off, then on. It's ok
* for the test to modify the provided SparkConf.
*/
- final protected def encryptionTest(name: String)(fn: SparkConf => Unit) {
+ final protected def encryptionTest(name: String)(fn: SparkConf => Unit): Unit = {
encryptionTestHelper(name) { case (name, conf) =>
test(name)(fn(conf))
}
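
[Reviewer note, not part of the patch] Most of the churn across these suites is the same mechanical rewrite: Scala's procedure syntax is deprecated in Scala 2.13 (and dropped in Scala 3), so every def f() { ... } gains an explicit result type. Sketch of the pattern applied throughout:

    // before (procedure syntax, deprecated)
    // def cleanup() { testFile.delete() }

    // after (explicit Unit result type)
    def cleanup(): Unit = {
      testFile.delete()
    }
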
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala
index 2915b99dcfb60..953b651c72a83 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.internal.config._
import org.apache.spark.internal.config.Kryo._
+import org.apache.spark.launcher.SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS
import org.apache.spark.serializer.KryoTest._
import org.apache.spark.util.ThreadUtils
@@ -71,6 +72,9 @@ object KryoSerializerBenchmark extends BenchmarkBase {
def createSparkContext(usePool: Boolean): SparkContext = {
val conf = new SparkConf()
+ // SPARK-29282 This is for consistency between JDK8 and JDK11.
+ conf.set(EXECUTOR_EXTRA_JAVA_OPTIONS,
+ "-XX:+UseParallelGC -XX:-UseDynamicNumberOfGCThreads")
conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
conf.set(KRYO_USER_REGISTRATORS, classOf[MyRegistrator].getName)
conf.set(KRYO_USE_POOL, usePool)
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
index 5d76c096d46ac..d4fafab4a5d64 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
@@ -56,7 +56,7 @@ object KryoDistributedTest {
class MyCustomClass
class AppJarRegistrator extends KryoRegistrator {
- override def registerClasses(k: Kryo) {
+ override def registerClasses(k: Kryo): Unit = {
k.register(Utils.classForName(AppJarRegistrator.customClassName,
noSparkClassLoader = true))
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 2442670b6d3f0..b5313fc24cd84 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -86,7 +86,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
conf.set(KRYO_REGISTRATION_REQUIRED, true)
val ser = new KryoSerializer(conf).newInstance()
- def check[T: ClassTag](t: T) {
+ def check[T: ClassTag](t: T): Unit = {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check(1)
@@ -119,7 +119,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
conf.set(KRYO_REGISTRATION_REQUIRED, true)
val ser = new KryoSerializer(conf).newInstance()
- def check[T: ClassTag](t: T) {
+ def check[T: ClassTag](t: T): Unit = {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check((1, 1))
@@ -146,7 +146,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
conf.set(KRYO_REGISTRATION_REQUIRED, true)
val ser = new KryoSerializer(conf).newInstance()
- def check[T: ClassTag](t: T) {
+ def check[T: ClassTag](t: T): Unit = {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check(List[Int]())
@@ -173,7 +173,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
test("Bug: SPARK-10251") {
val ser = new KryoSerializer(conf.clone.set(KRYO_REGISTRATION_REQUIRED, true))
.newInstance()
- def check[T: ClassTag](t: T) {
+ def check[T: ClassTag](t: T): Unit = {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check((1, 3))
@@ -202,7 +202,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
test("ranges") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T: ClassTag](t: T) {
+ def check[T: ClassTag](t: T): Unit = {
assert(ser.deserialize[T](ser.serialize(t)) === t)
// Check that very long ranges don't get written one element at a time
assert(ser.serialize(t).limit() < 200)
@@ -238,7 +238,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
test("custom registrator") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T: ClassTag](t: T) {
+ def check[T: ClassTag](t: T): Unit = {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
@@ -350,8 +350,11 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
val ser = new KryoSerializer(conf).newInstance()
val denseBlockSizes = new Array[Long](5000)
val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L)
+ var mapTaskId = 0
Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes =>
- ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes))
+ mapTaskId += 1
+ ser.serialize(HighlyCompressedMapStatus(
+ BlockManagerId("exec-1", "host", 1234), blockSizes, mapTaskId))
}
}
@@ -460,7 +463,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
val tests = mutable.ListBuffer[Future[Boolean]]()
- def check[T: ClassTag](t: T) {
+ def check[T: ClassTag](t: T): Unit = {
tests += Future {
val serializerInstance = ser.newInstance()
serializerInstance.deserialize[T](serializerInstance.serialize(t)) === t
@@ -579,7 +582,7 @@ object KryoTest {
}
class MyRegistrator extends KryoRegistrator {
- override def registerClasses(k: Kryo) {
+ override def registerClasses(k: Kryo): Unit = {
k.register(classOf[CaseClass])
k.register(classOf[ClassWithNoArgConstructor])
k.register(classOf[ClassWithoutNoArgConstructor])
@@ -588,7 +591,7 @@ object KryoTest {
}
class RegistratorWithoutAutoReset extends KryoRegistrator {
- override def registerClasses(k: Kryo) {
+ override def registerClasses(k: Kryo): Unit = {
k.setAutoReset(false)
}
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/UnsafeKryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/UnsafeKryoSerializerSuite.scala
index 126ba0e8b1e93..65f3793c421fa 100644
--- a/core/src/test/scala/org/apache/spark/serializer/UnsafeKryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/UnsafeKryoSerializerSuite.scala
@@ -23,12 +23,12 @@ class UnsafeKryoSerializerSuite extends KryoSerializerSuite {
// This test suite should run all tests in KryoSerializerSuite with kryo unsafe.
- override def beforeAll() {
+ override def beforeAll(): Unit = {
conf.set(KRYO_USE_UNSAFE, true)
super.beforeAll()
}
- override def afterAll() {
+ override def afterAll(): Unit = {
conf.set(KRYO_USE_UNSAFE, false)
super.afterAll()
}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 6d2ef17a7a790..d0cbb30fe0232 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -102,12 +102,13 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// Make a mocked MapOutputTracker for the shuffle reader to use to determine what
// shuffle data to read.
val mapOutputTracker = mock(classOf[MapOutputTracker])
- when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn {
+ when(mapOutputTracker.getMapSizesByExecutorId(
+ shuffleId, reduceId, reduceId + 1, useOldFetchProtocol = false)).thenReturn {
// Test a scenario where all data is local, to avoid creating a bunch of additional mocks
// for the code to read data over the network.
val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
- (shuffleBlockId, byteOutputStream.size().toLong)
+ (shuffleBlockId, byteOutputStream.size().toLong, mapId)
}
Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator
}
@@ -118,7 +119,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
when(dependency.serializer).thenReturn(serializer)
when(dependency.aggregator).thenReturn(None)
when(dependency.keyOrdering).thenReturn(None)
- new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ new BaseShuffleHandle(shuffleId, dependency)
}
val serializerManager = new SerializerManager(
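
[Reviewer note, not part of the patch] Two related signature changes show up in this suite: getMapSizesByExecutorId now yields (blockId, size, mapIndex) triples instead of (blockId, size) pairs, and BaseShuffleHandle no longer carries a numMaps field. A hedged sketch of the mocked tracker's new return shape (shuffleId, reduceId, numMaps, localBlockManagerId and the per-block size are assumed to be in scope, as in the test above):

    // One entry per map output feeding this reduce partition.
    val entries: Seq[(BlockId, Long, Int)] = (0 until numMaps).map { mapIndex =>
      (ShuffleBlockId(shuffleId, mapIndex, reduceId), blockSizeBytes, mapIndex)
    }
    Seq((localBlockManagerId, entries)).toIterator
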
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index b9f81fa0d0a06..f8474022867f4 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
-import org.mockito.ArgumentMatchers.{any, anyInt}
+import org.mockito.ArgumentMatchers.{any, anyInt, anyLong}
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfterEach
@@ -65,7 +65,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
taskMetrics = new TaskMetrics
shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int](
shuffleId = 0,
- numMaps = 2,
dependency = dependency
)
val memoryManager = new TestMemoryManager(conf)
@@ -78,7 +77,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
when(blockResolver.writeIndexFileAndCommit(
- anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])))
+ anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
.thenAnswer { invocationOnMock =>
val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
if (tmp != null) {
@@ -139,8 +138,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
shuffleHandle,
- 0, // MapId
- 0L, // MapTaskAttemptId
+ 0L, // MapId
conf,
taskContext.taskMetrics().shuffleWriteMetrics,
shuffleExecutorComponents)
@@ -166,8 +164,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
shuffleHandle,
- 0, // MapId
- 0L,
+ 0L, // MapId
transferConf,
taskContext.taskMetrics().shuffleWriteMetrics,
shuffleExecutorComponents)
@@ -202,8 +199,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
shuffleHandle,
- 0, // MapId
- 0L,
+ 0L, // MapId
conf,
taskContext.taskMetrics().shuffleWriteMetrics,
shuffleExecutorComponents)
@@ -224,8 +220,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
shuffleHandle,
- 0, // MapId
- 0L,
+ 0L, // MapId
conf,
taskContext.taskMetrics().shuffleWriteMetrics,
shuffleExecutorComponents)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala
index 8b955c98f7953..49055ab71c3fe 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.shuffle.sort
import java.lang.{Long => JLong}
import org.mockito.Mockito.when
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
index 0dd6040808f9e..4c5694fcf0305 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -57,7 +57,7 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with
when(dependency.serializer).thenReturn(serializer)
when(dependency.aggregator).thenReturn(None)
when(dependency.keyOrdering).thenReturn(None)
- new BaseShuffleHandle(shuffleId, numMaps = numMaps, dependency)
+ new BaseShuffleHandle(shuffleId, dependency)
}
shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents(
conf, blockManager, shuffleBlockResolver)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
index 5156cc2cc47a6..f92455912f510 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -23,7 +23,7 @@ import java.nio.file.Files
import java.util.Arrays
import org.mockito.Answers.RETURNS_SMART_NULLS
-import org.mockito.ArgumentMatchers.{any, anyInt}
+import org.mockito.ArgumentMatchers.{any, anyInt, anyLong}
import org.mockito.Mock
import org.mockito.Mockito.when
import org.mockito.MockitoAnnotations
@@ -73,9 +73,9 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
conf = new SparkConf()
.set("spark.app.id", "example.spark.app")
.set("spark.shuffle.unsafe.file.output.buffer", "16k")
- when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile)
+ when(blockResolver.getDataFile(anyInt, anyLong)).thenReturn(mergedOutputFile)
when(blockResolver.writeIndexFileAndCommit(
- anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])))
+ anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
.thenAnswer { invocationOnMock =>
partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
diff --git a/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala b/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala
index bb2d2633001f0..8e23de0053e00 100644
--- a/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala
+++ b/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.status
import org.apache.spark.SparkFunSuite
-import org.apache.spark.status.api.v1.RDDPartitionInfo
class LiveEntitySuite extends SparkFunSuite {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index ff4755833a916..0f3767c4f8c84 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -22,13 +22,13 @@ import java.util.UUID
import org.apache.spark.SparkFunSuite
class BlockIdSuite extends SparkFunSuite {
- def assertSame(id1: BlockId, id2: BlockId) {
+ def assertSame(id1: BlockId, id2: BlockId): Unit = {
assert(id1.name === id2.name)
assert(id1.hashCode === id2.hashCode)
assert(id1 === id2)
}
- def assertDifferent(id1: BlockId, id2: BlockId) {
+ def assertDifferent(id1: BlockId, id2: BlockId): Unit = {
assert(id1.name != id2.name)
assert(id1.hashCode != id2.hashCode)
assert(id1 != id2)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index 05a9ac685e5e7..d8f42ea9557d9 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -308,7 +308,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite
* is correct. Then it also drops the block from memory of each store (using LRU) and
* again checks whether the master's knowledge gets updated.
*/
- protected def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) {
+ protected def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]): Unit = {
import org.apache.spark.storage.StorageLevel._
assert(maxReplication > 1,
@@ -431,7 +431,7 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav
}
}
- def testProactiveReplication(replicationFactor: Int) {
+ def testProactiveReplication(replicationFactor: Int): Unit = {
val blockSize = 1000
val storeSize = 10000
val initialStores = (1 to 10).map { i => makeBlockManager(storeSize, s"store$i") }
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 509d4efcab67a..43a0cc7e31b40 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -28,8 +28,8 @@ import scala.language.implicitConversions
import scala.reflect.ClassTag
import org.apache.commons.lang3.RandomUtils
-import org.mockito.{ArgumentMatchers => mc}
-import org.mockito.Mockito.{doAnswer, mock, spy, times, verify, when}
+import org.mockito.{ArgumentCaptor, ArgumentMatchers => mc}
+import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.scalatest._
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
@@ -143,9 +143,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
// need to create a SparkContext is to initialize LiveListenerBus.
sc = mock(classOf[SparkContext])
when(sc.conf).thenReturn(conf)
- master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
- new BlockManagerMasterEndpoint(rpcEnv, true, conf,
- new LiveListenerBus(conf), None)), conf, true)
+ master = spy(new BlockManagerMaster(
+ rpcEnv.setupEndpoint("blockmanager",
+ new BlockManagerMasterEndpoint(rpcEnv, true, conf,
+ new LiveListenerBus(conf), None)), conf, true))
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
@@ -289,14 +290,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
eventually(timeout(1.second), interval(10.milliseconds)) {
assert(!store.hasLocalBlock("a1-to-remove"))
master.getLocations("a1-to-remove") should have size 0
+ assertUpdateBlockInfoReportedForRemovingBlock(store, "a1-to-remove",
+ removedFromMemory = true, removedFromDisk = false)
}
eventually(timeout(1.second), interval(10.milliseconds)) {
assert(!store.hasLocalBlock("a2-to-remove"))
master.getLocations("a2-to-remove") should have size 0
+ assertUpdateBlockInfoReportedForRemovingBlock(store, "a2-to-remove",
+ removedFromMemory = true, removedFromDisk = false)
}
eventually(timeout(1.second), interval(10.milliseconds)) {
assert(store.hasLocalBlock("a3-to-remove"))
master.getLocations("a3-to-remove") should have size 0
+ assertUpdateBlockInfoNotReported(store, "a3-to-remove")
}
eventually(timeout(1.second), interval(10.milliseconds)) {
val memStatus = master.getMemoryStatus.head._2
@@ -375,16 +381,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(!executorStore.hasLocalBlock(broadcast0BlockId))
assert(executorStore.hasLocalBlock(broadcast1BlockId))
assert(executorStore.hasLocalBlock(broadcast2BlockId))
+ assertUpdateBlockInfoReportedForRemovingBlock(executorStore, broadcast0BlockId,
+ removedFromMemory = false, removedFromDisk = true)
// nothing should be removed from the driver store
assert(driverStore.hasLocalBlock(broadcast0BlockId))
assert(driverStore.hasLocalBlock(broadcast1BlockId))
assert(driverStore.hasLocalBlock(broadcast2BlockId))
+ assertUpdateBlockInfoNotReported(driverStore, broadcast0BlockId)
// remove broadcast 0 block from the driver as well
master.removeBroadcast(0, removeFromMaster = true, blocking = true)
assert(!driverStore.hasLocalBlock(broadcast0BlockId))
assert(driverStore.hasLocalBlock(broadcast1BlockId))
+ assertUpdateBlockInfoReportedForRemovingBlock(driverStore, broadcast0BlockId,
+ removedFromMemory = false, removedFromDisk = true)
// remove broadcast 1 block from both the stores asynchronously
// and verify all broadcast 1 blocks have been removed
@@ -392,6 +403,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
eventually(timeout(1.second), interval(10.milliseconds)) {
assert(!driverStore.hasLocalBlock(broadcast1BlockId))
assert(!executorStore.hasLocalBlock(broadcast1BlockId))
+ assertUpdateBlockInfoReportedForRemovingBlock(driverStore, broadcast1BlockId,
+ removedFromMemory = false, removedFromDisk = true)
+ assertUpdateBlockInfoReportedForRemovingBlock(executorStore, broadcast1BlockId,
+ removedFromMemory = false, removedFromDisk = true)
}
// remove broadcast 2 from both the stores asynchronously
@@ -402,11 +417,46 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(!driverStore.hasLocalBlock(broadcast2BlockId2))
assert(!executorStore.hasLocalBlock(broadcast2BlockId))
assert(!executorStore.hasLocalBlock(broadcast2BlockId2))
+ assertUpdateBlockInfoReportedForRemovingBlock(driverStore, broadcast2BlockId,
+ removedFromMemory = false, removedFromDisk = true)
+ assertUpdateBlockInfoReportedForRemovingBlock(driverStore, broadcast2BlockId2,
+ removedFromMemory = false, removedFromDisk = true)
+ assertUpdateBlockInfoReportedForRemovingBlock(executorStore, broadcast2BlockId,
+ removedFromMemory = false, removedFromDisk = true)
+ assertUpdateBlockInfoReportedForRemovingBlock(executorStore, broadcast2BlockId2,
+ removedFromMemory = false, removedFromDisk = true)
}
executorStore.stop()
driverStore.stop()
}
+ private def assertUpdateBlockInfoReportedForRemovingBlock(
+ store: BlockManager,
+ blockId: BlockId,
+ removedFromMemory: Boolean,
+ removedFromDisk: Boolean): Unit = {
+ def assertSizeReported(captor: ArgumentCaptor[Long], expectRemoved: Boolean): Unit = {
+ assert(captor.getAllValues().size() === 1)
+ if (expectRemoved) {
+ assert(captor.getValue() > 0)
+ } else {
+ assert(captor.getValue() === 0)
+ }
+ }
+
+ val memSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]]
+ val diskSizeCaptor = ArgumentCaptor.forClass(classOf[Long]).asInstanceOf[ArgumentCaptor[Long]]
+ verify(master).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId),
+ mc.eq(StorageLevel.NONE), memSizeCaptor.capture(), diskSizeCaptor.capture())
+ assertSizeReported(memSizeCaptor, removedFromMemory)
+ assertSizeReported(diskSizeCaptor, removedFromDisk)
+ }
+
+ private def assertUpdateBlockInfoNotReported(store: BlockManager, blockId: BlockId): Unit = {
+ verify(master, never()).updateBlockInfo(mc.eq(store.blockManagerId), mc.eq(blockId),
+ mc.eq(StorageLevel.NONE), mc.anyInt(), mc.anyInt())
+ }
+
test("reregistration on heart beat") {
val store = makeBlockManager(2000)
val a1 = new Array[Byte](400)
@@ -451,18 +501,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
for (i <- 1 to 100) {
master.removeExecutor(store.blockManagerId.executorId)
val t1 = new Thread {
- override def run() {
+ override def run(): Unit = {
store.putIterator(
"a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
}
}
val t2 = new Thread {
- override def run() {
+ override def run(): Unit = {
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
}
}
val t3 = new Thread {
- override def run() {
+ override def run(): Unit = {
store.reregister()
}
}
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 0c4f3c48ef802..c757dee43808d 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.storage
import java.io.{File, FileWriter}
-import java.util.UUID
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
@@ -33,14 +32,14 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
var diskBlockManager: DiskBlockManager = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
rootDir0 = Utils.createTempDir()
rootDir1 = Utils.createTempDir()
rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
Utils.deleteRecursively(rootDir0)
Utils.deleteRecursively(rootDir1)
@@ -49,14 +48,14 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
}
}
- override def beforeEach() {
+ override def beforeEach(): Unit = {
super.beforeEach()
val conf = testConf.clone
conf.set("spark.local.dir", rootDirs)
diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true)
}
- override def afterEach() {
+ override def afterEach(): Unit = {
try {
diskBlockManager.stop()
} finally {
@@ -86,7 +85,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
assert(diskBlockManager.getAllBlocks().isEmpty)
}
- def writeToFile(file: File, numBytes: Int) {
+ def writeToFile(file: File, numBytes: Int): Unit = {
val writer = new FileWriter(file, true)
for (i <- 0 until numBytes) writer.write(i)
writer.close()
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
index 56860b2e55709..74442c2966a72 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
import org.mockito.ArgumentMatchers.{eq => meq}
import org.mockito.Mockito._
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.SparkFunSuite
import org.apache.spark.memory.MemoryMode.ON_HEAP
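
[Reviewer note, not part of the patch] The MockitoSugar and WebBrowser imports move because newer ScalaTest releases split the third-party integrations into separate org.scalatestplus artifacts. The pattern, as applied across this diff:

    // old locations (no longer shipped with ScalaTest itself)
    // import org.scalatest.mockito.MockitoSugar
    // import org.scalatest.selenium.WebBrowser

    // new locations (scalatestplus modules)
    import org.scalatestplus.mockito.MockitoSugar
    import org.scalatestplus.selenium.WebBrowser
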
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index ed402440e74f1..e5a615c2c2cbb 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -98,9 +98,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val transfer = createMockTransfer(remoteBlocks)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq),
- (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+ (localBmId, localBlocks.keys.map(blockId => (blockId, 1L, 0)).toSeq),
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq)
).toIterator
val taskContext = TaskContext.empty()
@@ -179,8 +179,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}
})
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)).toIterator
val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
@@ -247,8 +247,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}
})
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq))
+ .toIterator
val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
@@ -336,8 +337,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}
})
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)).toIterator
val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
@@ -389,8 +390,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val corruptBuffer1 = mockCorruptBuffer(streamLength, 0)
val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1", 1)
val shuffleBlockId1 = ShuffleBlockId(0, 1, 0)
- val blockLengths1 = Seq[Tuple2[BlockId, Long]](
- shuffleBlockId1 -> corruptBuffer1.size()
+ val blockLengths1 = Seq[Tuple3[BlockId, Long, Int]](
+ (shuffleBlockId1, corruptBuffer1.size(), 1)
)
val streamNotCorruptTill = 8 * 1024
@@ -398,13 +399,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill)
val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2", 2)
val shuffleBlockId2 = ShuffleBlockId(0, 2, 0)
- val blockLengths2 = Seq[Tuple2[BlockId, Long]](
- shuffleBlockId2 -> corruptBuffer2.size()
+ val blockLengths2 = Seq[Tuple3[BlockId, Long, Int]](
+ (shuffleBlockId2, corruptBuffer2.size(), 2)
)
val transfer = createMockTransfer(
Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2))
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
(blockManagerId1, blockLengths1),
(blockManagerId2, blockLengths2)
).toIterator
@@ -465,11 +466,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId
doReturn(managedBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0))
- val localBlockLengths = Seq[Tuple2[BlockId, Long]](
- ShuffleBlockId(0, 0, 0) -> 10000
+ val localBlockLengths = Seq[Tuple3[BlockId, Long, Int]](
+ (ShuffleBlockId(0, 0, 0), 10000, 0)
)
val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer))
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
(localBmId, localBlockLengths)
).toIterator
@@ -531,8 +532,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}
})
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq))
+ .toIterator
val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
@@ -591,7 +593,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
})
def fetchShuffleBlock(
- blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = {
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
// Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the
// construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks
// are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here.
@@ -611,15 +613,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
taskContext.taskMetrics.createTempShuffleReadMetrics())
}
- val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator
+ val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L, 0)).toSeq)).toIterator
fetchShuffleBlock(blocksByAddress1)
// `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
// shuffle block to disk.
assert(tempFileManager == null)
- val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator
+ val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L, 0)).toSeq)).toIterator
fetchShuffleBlock(blocksByAddress2)
// `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
// shuffle block to disk.
@@ -640,8 +642,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)))
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq))
val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
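
[Reviewer note, not part of the patch] The maxReqSizeShuffleToMem cases above exercise a simple size cut-off: remote blocks no larger than the threshold are buffered in memory, larger ones are streamed to a temporary file. A hedged sketch of that decision (names illustrative):

    def fetchesToDisk(blockSize: Long, maxReqSizeShuffleToMem: Long): Boolean =
      blockSize > maxReqSizeShuffleToMem

    fetchesToDisk(100L, 200L)  // false -> tempFileManager stays unused
    fetchesToDisk(300L, 200L)  // true  -> block is written to a download file
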
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
index 1913b8d425519..f0736348940ca 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -31,8 +31,8 @@ import org.openqa.selenium.{By, WebDriver}
import org.openqa.selenium.htmlunit.HtmlUnitDriver
import org.scalatest._
import org.scalatest.concurrent.Eventually._
-import org.scalatest.selenium.WebBrowser
import org.scalatest.time.SpanSugar._
+import org.scalatestplus.selenium.WebBrowser
import org.w3c.css.sac.CSSParseException
import org.apache.spark._
@@ -233,7 +233,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
test("spark.ui.killEnabled should properly control kill button display") {
def hasKillLink: Boolean = find(className("kill-link")).isDefined
- def runSlowJob(sc: SparkContext) {
+ def runSlowJob(sc: SparkContext): Unit = {
sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync()
}
@@ -316,10 +316,12 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
val env = SparkEnv.get
val bmAddress = env.blockManager.blockManagerId
val shuffleId = shuffleHandle.shuffleId
- val mapId = 0
+ val mapId = 0L
+ val mapIndex = 0
val reduceId = taskContext.partitionId()
val message = "Simulated fetch failure"
- throw new FetchFailedException(bmAddress, shuffleId, mapId, reduceId, message)
+ throw new FetchFailedException(
+ bmAddress, shuffleId, mapId, mapIndex, reduceId, message)
} else {
x
}
diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
index de105b6f188f5..82773e3cc6860 100644
--- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ui
import scala.xml.{Node, Text}
+import scala.xml.Utility.trim
import org.apache.spark.SparkFunSuite
@@ -129,6 +130,55 @@ class UIUtilsSuite extends SparkFunSuite {
assert(decoded1 === decodeURLParameter(decoded1))
}
+ test("listingTable with tooltips") {
+
+ def generateDataRowValue: String => Seq[Node] = row => {row}
+ val header = Seq("Header1", "Header2")
+ val data = Seq("Data1", "Data2")
+ val tooltip = Seq(None, Some("tooltip"))
+
+ val generated = listingTable(header, generateDataRowValue, data, tooltipHeaders = tooltip)
+
+ val expected: Node =
+
+
+ {header(0)}
+
+
+ {header(1)}
+
+
+
+
+ {data.map(generateDataRowValue)}
+
+
+
+ assert(trim(generated(0)) == trim(expected))
+ }
+
+ test("listingTable without tooltips") {
+
+ def generateDataRowValue: String => Seq[Node] = row => {row}
+ val header = Seq("Header1", "Header2")
+ val data = Seq("Data1", "Data2")
+
+ val generated = listingTable(header, generateDataRowValue, data)
+
+ val expected =
+
+
+ {header(0)}
+ {header(1)}
+
+
+ {data.map(generateDataRowValue)}
+
+
+
+ assert(trim(generated(0)) == trim(expected))
+ }
+
private def verify(
desc: String,
expected: Node,
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
index f5f93ece660b8..21e69550785a4 100644
--- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -356,7 +356,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging {
}
/** Delete all the generated rolled over files */
- def cleanup() {
+ def cleanup(): Unit = {
testFile.getParentFile.listFiles.filter { file =>
file.getName.startsWith(testFile.getName)
}.foreach { _.delete() }
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index e781c5f71faf4..a2a4b3aa974fc 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -179,7 +179,7 @@ class JsonProtocolSuite extends SparkFunSuite {
testJobResult(jobFailed)
// TaskEndReason
- val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 16L, 18, 19,
"Some exception")
val fetchMetadataFailed = new MetadataFetchFailedException(17,
19, "metadata Fetch failed exception").toTaskFailedReason
@@ -296,12 +296,12 @@ class JsonProtocolSuite extends SparkFunSuite {
test("FetchFailed backwards compatibility") {
// FetchFailed in Spark 1.1.0 does not have a "Message" property.
- val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 16L, 18, 19,
"ignored")
val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed)
.removeField({ _._1 == "Message" })
- val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
- "Unknown reason")
+ val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 16L,
+ 18, 19, "Unknown reason")
assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent))
}
@@ -496,59 +496,59 @@ private[spark] object JsonProtocolSuite extends Assertions {
private val nodeBlacklistedTime = 1421458952000L
private val nodeUnblacklistedTime = 1421458962000L
- private def testEvent(event: SparkListenerEvent, jsonString: String) {
+ private def testEvent(event: SparkListenerEvent, jsonString: String): Unit = {
val actualJsonString = compact(render(JsonProtocol.sparkEventToJson(event)))
val newEvent = JsonProtocol.sparkEventFromJson(parse(actualJsonString))
assertJsonStringEquals(jsonString, actualJsonString, event.getClass.getSimpleName)
assertEquals(event, newEvent)
}
- private def testRDDInfo(info: RDDInfo) {
+ private def testRDDInfo(info: RDDInfo): Unit = {
val newInfo = JsonProtocol.rddInfoFromJson(JsonProtocol.rddInfoToJson(info))
assertEquals(info, newInfo)
}
- private def testStageInfo(info: StageInfo) {
+ private def testStageInfo(info: StageInfo): Unit = {
val newInfo = JsonProtocol.stageInfoFromJson(JsonProtocol.stageInfoToJson(info))
assertEquals(info, newInfo)
}
- private def testStorageLevel(level: StorageLevel) {
+ private def testStorageLevel(level: StorageLevel): Unit = {
val newLevel = JsonProtocol.storageLevelFromJson(JsonProtocol.storageLevelToJson(level))
assertEquals(level, newLevel)
}
- private def testTaskMetrics(metrics: TaskMetrics) {
+ private def testTaskMetrics(metrics: TaskMetrics): Unit = {
val newMetrics = JsonProtocol.taskMetricsFromJson(JsonProtocol.taskMetricsToJson(metrics))
assertEquals(metrics, newMetrics)
}
- private def testBlockManagerId(id: BlockManagerId) {
+ private def testBlockManagerId(id: BlockManagerId): Unit = {
val newId = JsonProtocol.blockManagerIdFromJson(JsonProtocol.blockManagerIdToJson(id))
assert(id === newId)
}
- private def testTaskInfo(info: TaskInfo) {
+ private def testTaskInfo(info: TaskInfo): Unit = {
val newInfo = JsonProtocol.taskInfoFromJson(JsonProtocol.taskInfoToJson(info))
assertEquals(info, newInfo)
}
- private def testJobResult(result: JobResult) {
+ private def testJobResult(result: JobResult): Unit = {
val newResult = JsonProtocol.jobResultFromJson(JsonProtocol.jobResultToJson(result))
assertEquals(result, newResult)
}
- private def testTaskEndReason(reason: TaskEndReason) {
+ private def testTaskEndReason(reason: TaskEndReason): Unit = {
val newReason = JsonProtocol.taskEndReasonFromJson(JsonProtocol.taskEndReasonToJson(reason))
assertEquals(reason, newReason)
}
- private def testBlockId(blockId: BlockId) {
+ private def testBlockId(blockId: BlockId): Unit = {
val newBlockId = BlockId(blockId.toString)
assert(blockId === newBlockId)
}
- private def testExecutorInfo(info: ExecutorInfo) {
+ private def testExecutorInfo(info: ExecutorInfo): Unit = {
val newInfo = JsonProtocol.executorInfoFromJson(JsonProtocol.executorInfoToJson(info))
assertEquals(info, newInfo)
}
@@ -565,7 +565,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| Util methods for comparing events |
* --------------------------------- */
- private[spark] def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) {
+ private[spark] def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent): Unit = {
(event1, event2) match {
case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) =>
assert(e1.properties === e2.properties)
@@ -633,7 +633,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
}
}
- private def assertEquals(info1: StageInfo, info2: StageInfo) {
+ private def assertEquals(info1: StageInfo, info2: StageInfo): Unit = {
assert(info1.stageId === info2.stageId)
assert(info1.name === info2.name)
assert(info1.numTasks === info2.numTasks)
@@ -647,7 +647,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
assert(info1.details === info2.details)
}
- private def assertEquals(info1: RDDInfo, info2: RDDInfo) {
+ private def assertEquals(info1: RDDInfo, info2: RDDInfo): Unit = {
assert(info1.id === info2.id)
assert(info1.name === info2.name)
assert(info1.numPartitions === info2.numPartitions)
@@ -657,14 +657,14 @@ private[spark] object JsonProtocolSuite extends Assertions {
assertEquals(info1.storageLevel, info2.storageLevel)
}
- private def assertEquals(level1: StorageLevel, level2: StorageLevel) {
+ private def assertEquals(level1: StorageLevel, level2: StorageLevel): Unit = {
assert(level1.useDisk === level2.useDisk)
assert(level1.useMemory === level2.useMemory)
assert(level1.deserialized === level2.deserialized)
assert(level1.replication === level2.replication)
}
- private def assertEquals(info1: TaskInfo, info2: TaskInfo) {
+ private def assertEquals(info1: TaskInfo, info2: TaskInfo): Unit = {
assert(info1.taskId === info2.taskId)
assert(info1.index === info2.index)
assert(info1.attemptNumber === info2.attemptNumber)
@@ -679,12 +679,12 @@ private[spark] object JsonProtocolSuite extends Assertions {
assert(info1.accumulables === info2.accumulables)
}
- private def assertEquals(info1: ExecutorInfo, info2: ExecutorInfo) {
+ private def assertEquals(info1: ExecutorInfo, info2: ExecutorInfo): Unit = {
assert(info1.executorHost == info2.executorHost)
assert(info1.totalCores == info2.totalCores)
}
- private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) {
+ private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics): Unit = {
assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime)
assert(metrics1.executorDeserializeCpuTime === metrics2.executorDeserializeCpuTime)
assert(metrics1.executorRunTime === metrics2.executorRunTime)
@@ -700,23 +700,23 @@ private[spark] object JsonProtocolSuite extends Assertions {
assertBlocksEquals(metrics1.updatedBlockStatuses, metrics2.updatedBlockStatuses)
}
- private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) {
+ private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics): Unit = {
assert(metrics1.remoteBlocksFetched === metrics2.remoteBlocksFetched)
assert(metrics1.localBlocksFetched === metrics2.localBlocksFetched)
assert(metrics1.fetchWaitTime === metrics2.fetchWaitTime)
assert(metrics1.remoteBytesRead === metrics2.remoteBytesRead)
}
- private def assertEquals(metrics1: ShuffleWriteMetrics, metrics2: ShuffleWriteMetrics) {
+ private def assertEquals(metrics1: ShuffleWriteMetrics, metrics2: ShuffleWriteMetrics): Unit = {
assert(metrics1.bytesWritten === metrics2.bytesWritten)
assert(metrics1.writeTime === metrics2.writeTime)
}
- private def assertEquals(metrics1: InputMetrics, metrics2: InputMetrics) {
+ private def assertEquals(metrics1: InputMetrics, metrics2: InputMetrics): Unit = {
assert(metrics1.bytesRead === metrics2.bytesRead)
}
- private def assertEquals(result1: JobResult, result2: JobResult) {
+ private def assertEquals(result1: JobResult, result2: JobResult): Unit = {
(result1, result2) match {
case (JobSucceeded, JobSucceeded) =>
case (r1: JobFailed, r2: JobFailed) =>
@@ -725,13 +725,14 @@ private[spark] object JsonProtocolSuite extends Assertions {
}
}
- private def assertEquals(reason1: TaskEndReason, reason2: TaskEndReason) {
+ private def assertEquals(reason1: TaskEndReason, reason2: TaskEndReason): Unit = {
(reason1, reason2) match {
case (Success, Success) =>
case (Resubmitted, Resubmitted) =>
case (r1: FetchFailed, r2: FetchFailed) =>
assert(r1.shuffleId === r2.shuffleId)
assert(r1.mapId === r2.mapId)
+ assert(r1.mapIndex === r2.mapIndex)
assert(r1.reduceId === r2.reduceId)
assert(r1.bmAddress === r2.bmAddress)
assert(r1.message === r2.message)
@@ -761,7 +762,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
private def assertEquals(
details1: Map[String, Seq[(String, String)]],
- details2: Map[String, Seq[(String, String)]]) {
+ details2: Map[String, Seq[(String, String)]]): Unit = {
details1.zip(details2).foreach {
case ((key1, values1: Seq[(String, String)]), (key2, values2: Seq[(String, String)])) =>
assert(key1 === key2)
@@ -769,7 +770,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
}
}
- private def assertEquals(exception1: Exception, exception2: Exception) {
+ private def assertEquals(exception1: Exception, exception2: Exception): Unit = {
assert(exception1.getMessage === exception2.getMessage)
assertSeqEquals(
exception1.getStackTrace,
@@ -783,7 +784,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
}
}
- private def assertJsonStringEquals(expected: String, actual: String, metadata: String) {
+ private def assertJsonStringEquals(expected: String, actual: String, metadata: String): Unit = {
val expectedJson = parse(expected)
val actualJson = parse(actual)
if (expectedJson != actualJson) {
@@ -796,7 +797,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
}
}
- private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) {
+ private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit): Unit = {
assert(seq1.length === seq2.length)
seq1.zip(seq2).foreach { case (t1, t2) =>
assertEquals(t1, t2)
@@ -806,7 +807,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
private def assertOptionEquals[T](
opt1: Option[T],
opt2: Option[T],
- assertEquals: (T, T) => Unit) {
+ assertEquals: (T, T) => Unit): Unit = {
if (opt1.isDefined) {
assert(opt2.isDefined)
assertEquals(opt1.get, opt2.get)
@@ -825,11 +826,12 @@ private[spark] object JsonProtocolSuite extends Assertions {
assertSeqEquals(blocks1, blocks2, assertBlockEquals)
}
- private def assertBlockEquals(b1: (BlockId, BlockStatus), b2: (BlockId, BlockStatus)) {
+ private def assertBlockEquals(b1: (BlockId, BlockStatus), b2: (BlockId, BlockStatus)): Unit = {
assert(b1 === b2)
}
- private def assertStackTraceElementEquals(ste1: StackTraceElement, ste2: StackTraceElement) {
+ private def assertStackTraceElementEquals(ste1: StackTraceElement,
+ ste2: StackTraceElement): Unit = {
// This mimics the equals() method from Java 8 and earlier. Java 9 adds checks for
// class loader and module, which will cause them to be not equal, when we don't
// care about those
@@ -936,6 +938,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
t.setExecutorDeserializeCpuTime(a)
t.setExecutorRunTime(b)
t.setExecutorCpuTime(b)
+ t.setPeakExecutionMemory(c)
t.setResultSize(c)
t.setJvmGCTime(d)
t.setResultSerializationTime(a + b)
@@ -1241,6 +1244,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Executor Deserialize CPU Time": 300,
| "Executor Run Time": 400,
| "Executor CPU Time": 400,
+ | "Peak Execution Memory": 500,
| "Result Size": 500,
| "JVM GC Time": 600,
| "Result Serialization Time": 700,
@@ -1364,6 +1368,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Executor Deserialize CPU Time": 300,
| "Executor Run Time": 400,
| "Executor CPU Time": 400,
+ | "Peak Execution Memory": 500,
| "Result Size": 500,
| "JVM GC Time": 600,
| "Result Serialization Time": 700,
@@ -1487,6 +1492,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Executor Deserialize CPU Time": 300,
| "Executor Run Time": 400,
| "Executor CPU Time": 400,
+ | "Peak Execution Memory": 500,
| "Result Size": 500,
| "JVM GC Time": 600,
| "Result Serialization Time": 700,
@@ -2050,7 +2056,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| {
| "ID": 9,
| "Name": "$PEAK_EXECUTION_MEMORY",
- | "Update": 0,
+ | "Update": 500,
| "Internal": true,
| "Count Failed Values": true
| },
diff --git a/core/src/test/scala/org/apache/spark/util/KeyLockSuite.scala b/core/src/test/scala/org/apache/spark/util/KeyLockSuite.scala
index 2169a0e4d442f..6888e492a8d33 100644
--- a/core/src/test/scala/org/apache/spark/util/KeyLockSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/KeyLockSuite.scala
@@ -49,7 +49,7 @@ class KeyLockSuite extends SparkFunSuite with TimeLimits {
@volatile var e: Throwable = null
val threads = (0 until numThreads).map { i =>
new Thread() {
- override def run(): Unit = try {
+ override def run(): Unit = {
latch.await(foreverMs, TimeUnit.MILLISECONDS)
keyLock.withLock(keys(i)) {
var cur = numThreadsHoldingLock.get()
diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
index 4b7164d8acbce..1efd399b5db68 100644
--- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
@@ -81,7 +81,7 @@ class NextIteratorSuite extends SparkFunSuite with Matchers {
}
}
- override def close() {
+ override def close(): Unit = {
closeCalled += 1
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/PropertiesCloneBenchmark.scala b/core/src/test/scala/org/apache/spark/util/PropertiesCloneBenchmark.scala
new file mode 100644
index 0000000000000..0726886c70fe6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/PropertiesCloneBenchmark.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util
+
+import java.util.Properties
+
+import scala.util.Random
+
+import org.apache.commons.lang.SerializationUtils
+
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+
+
+/**
+ * Benchmark for cloning java.util.Properties objects.
+ * To run this benchmark:
+ * {{{
+ * 1. without sbt:
+ * bin/spark-submit --class <this class> --jars <spark core test jar>
+ * 2. build/sbt "core/test:runMain <this class>"
+ * 3. generate result:
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/test:runMain <this class>"
+ * Results will be written to "benchmarks/PropertiesCloneBenchmark-results.txt".
+ * }}}
+ */
+object PropertiesCloneBenchmark extends BenchmarkBase {
+ /**
+ * Benchmark various cases of cloning properties objects
+ */
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ runBenchmark("Properties Cloning") {
+ def compareSerialization(name: String, props: Properties): Unit = {
+ val benchmark = new Benchmark(name, 1, output = output)
+ benchmark.addCase("SerializationUtils.clone") { _ =>
+ SerializationUtils.clone(props)
+ }
+ benchmark.addCase("Utils.cloneProperties") { _ =>
+ Utils.cloneProperties(props)
+ }
+ benchmark.run()
+ }
+ compareSerialization("Empty Properties", new Properties)
+ compareSerialization("System Properties", System.getProperties)
+ compareSerialization("Small Properties", makeRandomProps(10, 40, 100))
+ compareSerialization("Medium Properties", makeRandomProps(50, 40, 100))
+ compareSerialization("Large Properties", makeRandomProps(100, 40, 100))
+ }
+ }
+
+ def makeRandomProps(numProperties: Int, keySize: Int, valueSize: Int): Properties = {
+ val props = new Properties
+ for (_ <- 1 to numProperties) {
+ props.put(
+ Random.alphanumeric.take(keySize).mkString,
+ Random.alphanumeric.take(valueSize).mkString
+ )
+ }
+ props
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala
index 75e4504850679..0b1796540abbb 100644
--- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala
+++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala
@@ -19,7 +19,6 @@ package org.apache.spark.util
import java.util.Properties
-import org.apache.commons.lang3.SerializationUtils
import org.scalatest.{BeforeAndAfterEach, Suite}
/**
@@ -43,11 +42,11 @@ private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Su
var oldProperties: Properties = null
override def beforeEach(): Unit = {
- // we need SerializationUtils.clone instead of `new Properties(System.getProperties())` because
+ // we need Utils.cloneProperties instead of `new Properties(System.getProperties())` because
// the later way of creating a copy does not copy the properties but it initializes a new
// Properties object with the given properties as defaults. They are not recognized at all
// by standard Scala wrapper over Java Properties then.
- oldProperties = SerializationUtils.clone(System.getProperties)
+ oldProperties = Utils.cloneProperties(System.getProperties)
super.beforeEach()
}
diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
index 8bc62db81e4f9..73bf7762f37c1 100644
--- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
@@ -73,7 +73,7 @@ class SizeEstimatorSuite
with PrivateMethodTester
with ResetSystemProperties {
- override def beforeEach() {
+ override def beforeEach(): Unit = {
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
super.beforeEach()
System.setProperty("os.arch", "amd64")
diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
index aa3f062e582c3..ac36e537c75bb 100644
--- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
@@ -132,7 +132,7 @@ class ThreadUtilsSuite extends SparkFunSuite {
val t = new Thread() {
setDaemon(true)
- override def run() {
+ override def run(): Unit = {
try {
// "par" is uninterruptible. The following will keep running even if the thread is
// interrupted. We should prefer to use "ThreadUtils.parmap".
diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala
index 77a92e7e1eb43..1644540946839 100644
--- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala
@@ -63,7 +63,7 @@ class TimeStampedHashMapSuite extends SparkFunSuite {
}
/** Test basic operations of a Scala mutable Map. */
- def testMap(hashMapConstructor: => mutable.Map[String, String]) {
+ def testMap(hashMapConstructor: => mutable.Map[String, String]): Unit = {
def newMap() = hashMapConstructor
val testMap1 = newMap()
val testMap2 = newMap()
@@ -134,7 +134,7 @@ class TimeStampedHashMapSuite extends SparkFunSuite {
}
/** Test thread safety of a Scala mutable map. */
- def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]) {
+ def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]): Unit = {
def newMap() = hashMapConstructor
val name = newMap().getClass.getSimpleName
val testMap = newMap()
@@ -150,7 +150,7 @@ class TimeStampedHashMapSuite extends SparkFunSuite {
}
val threads = (1 to 25).map(i => new Thread() {
- override def run() {
+ override def run(): Unit = {
try {
for (j <- 1 to 1000) {
Random.nextInt(3) match {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 2bad56d7ff424..a6de64b6c68a0 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -294,7 +294,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
| Helper methods that contain the test body |
* =========================================== */
- private def emptyDataStream(conf: SparkConf) {
+ private def emptyDataStream(conf: SparkConf): Unit = {
conf.set(SHUFFLE_MANAGER, "sort")
sc = new SparkContext("local", "test", conf)
val context = MemoryTestingUtils.fakeTaskContext(sc.env)
@@ -327,7 +327,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
sorter4.stop()
}
- private def fewElementsPerPartition(conf: SparkConf) {
+ private def fewElementsPerPartition(conf: SparkConf): Unit = {
conf.set(SHUFFLE_MANAGER, "sort")
sc = new SparkContext("local", "test", conf)
val context = MemoryTestingUtils.fakeTaskContext(sc.env)
@@ -368,7 +368,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
sorter4.stop()
}
- private def emptyPartitionsWithSpilling(conf: SparkConf) {
+ private def emptyPartitionsWithSpilling(conf: SparkConf): Unit = {
val size = 1000
conf.set(SHUFFLE_MANAGER, "sort")
conf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, size / 2)
@@ -393,7 +393,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
sorter.stop()
}
- private def testSpillingInLocalCluster(conf: SparkConf, numReduceTasks: Int) {
+ private def testSpillingInLocalCluster(conf: SparkConf, numReduceTasks: Int): Unit = {
val size = 5000
conf.set(SHUFFLE_MANAGER, "sort")
conf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, size / 4)
@@ -517,7 +517,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
conf: SparkConf,
withPartialAgg: Boolean,
withOrdering: Boolean,
- withSpilling: Boolean) {
+ withSpilling: Boolean): Unit = {
val size = 1000
if (withSpilling) {
conf.set(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD, size / 2)
@@ -551,7 +551,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
assert(results === expected)
}
- private def sortWithoutBreakingSortingContracts(conf: SparkConf) {
+ private def sortWithoutBreakingSortingContracts(conf: SparkConf): Unit = {
val size = 100000
val conf = createSparkConf(loadDefaults = true, kryo = false)
conf.set(SHUFFLE_MANAGER, "sort")
diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
index 4759a830da4ca..8aa4be6c2ff8d 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
@@ -71,7 +71,7 @@ class SizeTrackerSuite extends SparkFunSuite {
testMap[String, Int](10000, i => (randString(0, 10000), i))
}
- def testVector[T: ClassTag](numElements: Int, makeElement: Int => T) {
+ def testVector[T: ClassTag](numElements: Int, makeElement: Int => T): Unit = {
val vector = new SizeTrackingVector[T]
for (i <- 0 until numElements) {
val item = makeElement(i)
@@ -80,7 +80,7 @@ class SizeTrackerSuite extends SparkFunSuite {
}
}
- def testMap[K, V](numElements: Int, makeElement: (Int) => (K, V)) {
+ def testMap[K, V](numElements: Int, makeElement: (Int) => (K, V)): Unit = {
val map = new SizeTrackingAppendOnlyMap[K, V]
for (i <- 0 until numElements) {
val (k, v) = makeElement(i)
@@ -89,7 +89,7 @@ class SizeTrackerSuite extends SparkFunSuite {
}
}
- def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) {
+ def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double): Unit = {
val betterEstimatedSize = SizeEstimator.estimate(obj)
assert(betterEstimatedSize * (1 - error) < estimatedSize,
s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize")
diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
index e80bd96c982df..d1603b85a8e94 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
@@ -311,12 +311,13 @@ abstract class AbstractIntArraySortDataFormat[K] extends SortDataFormat[K, Array
data(pos1) = tmp
}
- override def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) {
+ override def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int): Unit = {
dst(dstPos) = src(srcPos)
}
/** Copy a range of elements starting at src(srcPos) to dest, starting at destPos. */
- override def copyRange(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int, length: Int) {
+ override def copyRange(src: Array[Int], srcPos: Int,
+ dst: Array[Int], dstPos: Int, length: Int): Unit = {
System.arraycopy(src, srcPos, dst, dstPos, length)
}
@@ -334,13 +335,13 @@ abstract class AbstractByteArraySortDataFormat[K] extends SortDataFormat[K, Arra
data(pos1) = tmp
}
- override def copyElement(src: Array[Byte], srcPos: Int, dst: Array[Byte], dstPos: Int) {
+ override def copyElement(src: Array[Byte], srcPos: Int, dst: Array[Byte], dstPos: Int): Unit = {
dst(dstPos) = src(srcPos)
}
/** Copy a range of elements starting at src(srcPos) to dest, starting at destPos. */
override def copyRange(src: Array[Byte],
- srcPos: Int, dst: Array[Byte], dstPos: Int, length: Int) {
+ srcPos: Int, dst: Array[Byte], dstPos: Int, length: Int): Unit = {
System.arraycopy(src, srcPos, dst, dstPos, length)
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index 38cb37c524594..a55004f664a54 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.util.collection.unsafe.sort
import java.nio.charset.StandardCharsets
import com.google.common.primitives.UnsignedBytes
-import org.scalatest.prop.PropertyChecks
+import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks
import org.apache.spark.SparkFunSuite
import org.apache.spark.unsafe.types.UTF8String
-class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
+class PrefixComparatorsSuite extends SparkFunSuite with ScalaCheckPropertyChecks {
test("String prefix comparator") {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index a3c006b43d8e4..9ae6a8ef879f3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -108,7 +108,8 @@ class RadixSortSuite extends SparkFunSuite with Logging {
}
}
- private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
+ private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long,
+ refCmp: PrefixComparator): Unit = {
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
buf, Ints.checkedCast(lo), Ints.checkedCast(hi),
diff --git a/core/src/test/scala/org/apache/spark/util/logging/DriverLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/logging/DriverLoggerSuite.scala
index 973f71cdeb755..bd7ec242a9317 100644
--- a/core/src/test/scala/org/apache/spark/util/logging/DriverLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/logging/DriverLoggerSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.util.logging
-import java.io.{BufferedInputStream, File, FileInputStream}
+import java.io.File
import org.apache.commons.io.FileUtils
diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1
index d33a107cc86a5..e0976e66db29f 100644
--- a/dev/appveyor-install-dependencies.ps1
+++ b/dev/appveyor-install-dependencies.ps1
@@ -90,7 +90,7 @@ Invoke-Expression "7z.exe x maven.zip"
# add maven to environment variables
$env:PATH = "$tools\apache-maven-$mavenVer\bin;" + $env:PATH
$env:M2_HOME = "$tools\apache-maven-$mavenVer"
-$env:MAVEN_OPTS = "-Xmx2g -XX:ReservedCodeCacheSize=512m"
+$env:MAVEN_OPTS = "-Xmx2g -XX:ReservedCodeCacheSize=1g"
Pop-Location
diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh
index c1a122ebfb12e..f643c060eb321 100755
--- a/dev/create-release/do-release-docker.sh
+++ b/dev/create-release/do-release-docker.sh
@@ -127,6 +127,7 @@ GPG_KEY=$GPG_KEY
ASF_PASSWORD=$ASF_PASSWORD
GPG_PASSPHRASE=$GPG_PASSPHRASE
RELEASE_STEP=$RELEASE_STEP
+USER=$USER
EOF
JAVA_VOL=
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index f35bc4f48652b..61951e73f4bab 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -164,7 +164,6 @@ DEST_DIR_NAME="$SPARK_PACKAGE_VERSION"
git clean -d -f -x
rm .gitignore
-rm -rf .git
cd ..
if [[ "$1" == "package" ]]; then
@@ -179,7 +178,7 @@ if [[ "$1" == "package" ]]; then
rm -r spark-$SPARK_VERSION/licenses-binary
fi
- tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION
+ tar cvzf spark-$SPARK_VERSION.tgz --exclude spark-$SPARK_VERSION/.git spark-$SPARK_VERSION
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \
--detach-sig spark-$SPARK_VERSION.tgz
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 96cc76d0f2abb..3d15fc627f6de 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -1,8 +1,11 @@
+JLargeArrays-1.5.jar
+JTransforms-3.1.jar
JavaEWAH-0.3.2.jar
RoaringBitmap-0.7.45.jar
ST4-4.0.4.jar
activation-1.1.1.jar
aircompressor-0.10.jar
+algebra_2.12-2.0.0-M2.jar
antlr-2.7.7.jar
antlr-runtime-3.4.jar
antlr4-runtime-4.7.1.jar
@@ -17,13 +20,15 @@ arpack_combined_all-0.1.jar
arrow-format-0.12.0.jar
arrow-memory-0.12.0.jar
arrow-vector-0.12.0.jar
+audience-annotations-0.5.0.jar
automaton-1.11-8.jar
avro-1.8.2.jar
avro-ipc-1.8.2.jar
avro-mapred-1.8.2-hadoop2.jar
bonecp-0.8.0.RELEASE.jar
-breeze-macros_2.12-0.13.2.jar
-breeze_2.12-0.13.2.jar
+breeze-macros_2.12-1.0.jar
+breeze_2.12-1.0.jar
+cats-kernel_2.12-2.0.0-M4.jar
chill-java-0.9.3.jar
chill_2.12-0.9.3.jar
commons-beanutils-1.9.3.jar
@@ -84,16 +89,16 @@ httpclient-4.5.6.jar
httpcore-4.4.10.jar
istack-commons-runtime-3.0.8.jar
ivy-2.4.0.jar
-jackson-annotations-2.9.9.jar
-jackson-core-2.9.9.jar
+jackson-annotations-2.9.10.jar
+jackson-core-2.9.10.jar
jackson-core-asl-1.9.13.jar
-jackson-databind-2.9.9.3.jar
-jackson-dataformat-yaml-2.9.9.jar
+jackson-databind-2.9.10.jar
+jackson-dataformat-yaml-2.9.10.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-jaxb-annotations-2.9.9.jar
-jackson-module-paranamer-2.9.9.jar
-jackson-module-scala_2.12-2.9.9.jar
+jackson-module-jaxb-annotations-2.9.10.jar
+jackson-module-paranamer-2.9.10.jar
+jackson-module-scala_2.12-2.9.10.jar
jackson-xc-1.9.13.jar
jakarta.annotation-api-1.3.4.jar
jakarta.inject-2.5.0.jar
@@ -130,7 +135,6 @@ json4s-scalap_2.12-3.6.6.jar
jsp-api-2.1.jar
jsr305-3.0.0.jar
jta-1.1.jar
-jtransforms-2.4.0.jar
jul-to-slf4j-1.7.16.jar
kryo-shaded-4.0.2.jar
kubernetes-client-4.4.2.jar
@@ -142,7 +146,7 @@ libthrift-0.12.0.jar
log4j-1.2.17.jar
logging-interceptor-3.12.0.jar
lz4-java-1.6.0.jar
-machinist_2.12-0.6.1.jar
+machinist_2.12-0.6.8.jar
macro-compat_2.12-1.1.1.jar
mesos-1.4.0-shaded-protobuf.jar
metrics-core-3.1.5.jar
@@ -156,9 +160,9 @@ okapi-shade-0.4.2.jar
okhttp-3.8.1.jar
okio-1.13.0.jar
opencsv-2.3.jar
-orc-core-1.5.5-nohive.jar
-orc-mapreduce-1.5.5-nohive.jar
-orc-shims-1.5.5.jar
+orc-core-1.5.6-nohive.jar
+orc-mapreduce-1.5.6-nohive.jar
+orc-shims-1.5.6.jar
oro-2.0.8.jar
osgi-resource-locator-1.0.3.jar
paranamer-2.8.jar
@@ -172,20 +176,23 @@ parquet-jackson-1.10.1.jar
protobuf-java-2.5.0.jar
py4j-0.10.8.1.jar
pyrolite-4.30.jar
-scala-compiler-2.12.8.jar
-scala-library-2.12.8.jar
-scala-parser-combinators_2.12-1.1.0.jar
-scala-reflect-2.12.8.jar
+scala-collection-compat_2.12-2.1.1.jar
+scala-compiler-2.12.10.jar
+scala-library-2.12.10.jar
+scala-parser-combinators_2.12-1.1.2.jar
+scala-reflect-2.12.10.jar
scala-xml_2.12-1.2.0.jar
-shapeless_2.12-2.3.2.jar
+shapeless_2.12-2.3.3.jar
shims-0.7.45.jar
slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snakeyaml-1.23.jar
snappy-0.2.jar
snappy-java-1.1.7.3.jar
-spire-macros_2.12-0.13.0.jar
-spire_2.12-0.13.0.jar
+spire-macros_2.12-0.17.0-M1.jar
+spire-platform_2.12-0.17.0-M1.jar
+spire-util_2.12-0.17.0-M1.jar
+spire_2.12-0.17.0-M1.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.9.6.jar
@@ -198,5 +205,5 @@ xercesImpl-2.9.1.jar
xmlenc-0.52.jar
xz-1.5.jar
zjsonpatch-0.3.0.jar
-zookeeper-3.4.6.jar
-zstd-jni-1.4.2-1.jar
+zookeeper-3.4.14.jar
+zstd-jni-1.4.3-1.jar
diff --git a/dev/deps/spark-deps-hadoop-3.2 b/dev/deps/spark-deps-hadoop-3.2
index a3a5b51226462..6318217d4332f 100644
--- a/dev/deps/spark-deps-hadoop-3.2
+++ b/dev/deps/spark-deps-hadoop-3.2
@@ -1,9 +1,12 @@
+JLargeArrays-1.5.jar
+JTransforms-3.1.jar
JavaEWAH-0.3.2.jar
RoaringBitmap-0.7.45.jar
ST4-4.0.4.jar
accessors-smart-1.2.jar
activation-1.1.1.jar
aircompressor-0.10.jar
+algebra_2.12-2.0.0-M2.jar
antlr-2.7.7.jar
antlr-runtime-3.4.jar
antlr4-runtime-4.7.1.jar
@@ -20,8 +23,9 @@ avro-1.8.2.jar
avro-ipc-1.8.2.jar
avro-mapred-1.8.2-hadoop2.jar
bonecp-0.8.0.RELEASE.jar
-breeze-macros_2.12-0.13.2.jar
-breeze_2.12-0.13.2.jar
+breeze-macros_2.12-1.0.jar
+breeze_2.12-1.0.jar
+cats-kernel_2.12-2.0.0-M4.jar
chill-java-0.9.3.jar
chill_2.12-0.9.3.jar
commons-beanutils-1.9.3.jar
@@ -85,17 +89,17 @@ httpclient-4.5.6.jar
httpcore-4.4.10.jar
istack-commons-runtime-3.0.8.jar
ivy-2.4.0.jar
-jackson-annotations-2.9.9.jar
-jackson-core-2.9.9.jar
+jackson-annotations-2.9.10.jar
+jackson-core-2.9.10.jar
jackson-core-asl-1.9.13.jar
-jackson-databind-2.9.9.3.jar
-jackson-dataformat-yaml-2.9.9.jar
+jackson-databind-2.9.10.jar
+jackson-dataformat-yaml-2.9.10.jar
jackson-jaxrs-base-2.9.5.jar
jackson-jaxrs-json-provider-2.9.5.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-jaxb-annotations-2.9.9.jar
-jackson-module-paranamer-2.9.9.jar
-jackson-module-scala_2.12-2.9.9.jar
+jackson-module-jaxb-annotations-2.9.10.jar
+jackson-module-paranamer-2.9.10.jar
+jackson-module-scala_2.12-2.9.10.jar
jakarta.annotation-api-1.3.4.jar
jakarta.inject-2.5.0.jar
jakarta.ws.rs-api-2.1.5.jar
@@ -132,7 +136,6 @@ json4s-scalap_2.12-3.6.6.jar
jsp-api-2.1.jar
jsr305-3.0.0.jar
jta-1.1.jar
-jtransforms-2.4.0.jar
jul-to-slf4j-1.7.16.jar
kerb-admin-1.0.1.jar
kerb-client-1.0.1.jar
@@ -158,7 +161,7 @@ libthrift-0.12.0.jar
log4j-1.2.17.jar
logging-interceptor-3.12.0.jar
lz4-java-1.6.0.jar
-machinist_2.12-0.6.1.jar
+machinist_2.12-0.6.8.jar
macro-compat_2.12-1.1.1.jar
mesos-1.4.0-shaded-protobuf.jar
metrics-core-3.1.5.jar
@@ -175,9 +178,9 @@ okhttp-2.7.5.jar
okhttp-3.8.1.jar
okio-1.13.0.jar
opencsv-2.3.jar
-orc-core-1.5.5-nohive.jar
-orc-mapreduce-1.5.5-nohive.jar
-orc-shims-1.5.5.jar
+orc-core-1.5.6-nohive.jar
+orc-mapreduce-1.5.6-nohive.jar
+orc-shims-1.5.6.jar
oro-2.0.8.jar
osgi-resource-locator-1.0.3.jar
paranamer-2.8.jar
@@ -191,20 +194,23 @@ protobuf-java-2.5.0.jar
py4j-0.10.8.1.jar
pyrolite-4.30.jar
re2j-1.1.jar
-scala-compiler-2.12.8.jar
-scala-library-2.12.8.jar
-scala-parser-combinators_2.12-1.1.0.jar
-scala-reflect-2.12.8.jar
+scala-collection-compat_2.12-2.1.1.jar
+scala-compiler-2.12.10.jar
+scala-library-2.12.10.jar
+scala-parser-combinators_2.12-1.1.2.jar
+scala-reflect-2.12.10.jar
scala-xml_2.12-1.2.0.jar
-shapeless_2.12-2.3.2.jar
+shapeless_2.12-2.3.3.jar
shims-0.7.45.jar
slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snakeyaml-1.23.jar
snappy-0.2.jar
snappy-java-1.1.7.3.jar
-spire-macros_2.12-0.13.0.jar
-spire_2.12-0.13.0.jar
+spire-macros_2.12-0.17.0-M1.jar
+spire-platform_2.12-0.17.0-M1.jar
+spire-util_2.12-0.17.0-M1.jar
+spire_2.12-0.17.0-M1.jar
stax-api-1.0.1.jar
stax2-api-3.1.4.jar
stream-2.9.6.jar
@@ -217,5 +223,5 @@ woodstox-core-5.0.3.jar
xbean-asm7-shaded-4.14.jar
xz-1.5.jar
zjsonpatch-0.3.0.jar
-zookeeper-3.4.13.jar
-zstd-jni-1.4.2-1.jar
+zookeeper-3.4.14.jar
+zstd-jni-1.4.3-1.jar
diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh
index a550af93feecd..cd18b6870e07c 100755
--- a/dev/make-distribution.sh
+++ b/dev/make-distribution.sh
@@ -160,7 +160,7 @@ fi
# Build uber fat JAR
cd "$SPARK_HOME"
-export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:ReservedCodeCacheSize=512m}"
+export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:ReservedCodeCacheSize=1g}"
# Store the command as an array because $MVN variable might have spaces in it.
# Normal quoting tricks don't work.
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index fa3d50b8989f1..967cdace60dc9 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -97,9 +97,9 @@ def fail(msg):
def run_cmd(cmd):
print(cmd)
if isinstance(cmd, list):
- return subprocess.check_output(cmd).decode(sys.getdefaultencoding())
+ return subprocess.check_output(cmd).decode('utf-8')
else:
- return subprocess.check_output(cmd.split(" ")).decode(sys.getdefaultencoding())
+ return subprocess.check_output(cmd.split(" ")).decode('utf-8')
def continue_maybe(prompt):
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 0f6dbf2f99a97..c7ea065b28ed8 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -362,7 +362,6 @@ def __hash__(self):
"pyspark.sql.window",
"pyspark.sql.avro.functions",
# unittests
- "pyspark.sql.tests.test_appsubmit",
"pyspark.sql.tests.test_arrow",
"pyspark.sql.tests.test_catalog",
"pyspark.sql.tests.test_column",
@@ -373,6 +372,7 @@ def __hash__(self):
"pyspark.sql.tests.test_functions",
"pyspark.sql.tests.test_group",
"pyspark.sql.tests.test_pandas_udf",
+ "pyspark.sql.tests.test_pandas_udf_cogrouped_map",
"pyspark.sql.tests.test_pandas_udf_grouped_agg",
"pyspark.sql.tests.test_pandas_udf_grouped_map",
"pyspark.sql.tests.test_pandas_udf_scalar",
diff --git a/docs/_config.yml b/docs/_config.yml
index 146c90fcff6e5..57b8d716ee55c 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -17,7 +17,7 @@ include:
SPARK_VERSION: 3.0.0-SNAPSHOT
SPARK_VERSION_SHORT: 3.0.0
SCALA_BINARY_VERSION: "2.12"
-SCALA_VERSION: "2.12.8"
+SCALA_VERSION: "2.12.10"
MESOS_VERSION: 1.0.0
SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK
SPARK_GITHUB_URL: https://github.com/apache/spark
diff --git a/docs/_data/menu-migration.yaml b/docs/_data/menu-migration.yaml
new file mode 100644
index 0000000000000..1d8b311dd64fb
--- /dev/null
+++ b/docs/_data/menu-migration.yaml
@@ -0,0 +1,12 @@
+- text: Spark Core
+ url: core-migration-guide.html
+- text: SQL, Datasets and DataFrame
+ url: sql-migration-guide.html
+- text: Structured Streaming
+ url: ss-migration-guide.html
+- text: MLlib (Machine Learning)
+ url: ml-migration-guide.html
+- text: PySpark (Python on Spark)
+ url: pyspark-migration-guide.html
+- text: SparkR (R on Spark)
+ url: sparkr-migration-guide.html
diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml
index 717911b5a4645..edcdad4ee7db5 100644
--- a/docs/_data/menu-sql.yaml
+++ b/docs/_data/menu-sql.yaml
@@ -64,20 +64,14 @@
- text: Usage Notes
url: sql-pyspark-pandas-with-arrow.html#usage-notes
- text: Migration Guide
- url: sql-migration-guide.html
- subitems:
- - text: Spark SQL Upgrading Guide
- url: sql-migration-guide-upgrade.html
- - text: Compatibility with Apache Hive
- url: sql-migration-guide-hive-compatibility.html
- - text: SQL Reserved/Non-Reserved Keywords
- url: sql-reserved-and-non-reserved-keywords.html
-
+ url: sql-migration-old.html
- text: SQL Reference
url: sql-ref.html
subitems:
- text: Data Types
url: sql-ref-datatypes.html
+ - text: Null Semantics
+ url: sql-ref-null-semantics.html
- text: NaN Semantics
url: sql-ref-nan-semantics.html
- text: SQL Syntax
@@ -139,6 +133,8 @@
url: sql-ref-syntax-qry-select-limit.html
- text: Set operations
url: sql-ref-syntax-qry-select-setops.html
+ - text: USE database
+ url: sql-ref-syntax-qry-select-usedb.html
- text: Common Table Expression(CTE)
url: sql-ref-syntax-qry-select-cte.html
- text: Subqueries
@@ -170,6 +166,8 @@
url: sql-ref-syntax-aux-cache-uncache-table.html
- text: CLEAR CACHE
url: sql-ref-syntax-aux-cache-clear-cache.html
+ - text: REFRESH TABLE
+ url: sql-ref-syntax-aux-refresh-table.html
- text: Describe Commands
url: sql-ref-syntax-aux-describe.html
subitems:
diff --git a/docs/_includes/nav-left-wrapper-migration.html b/docs/_includes/nav-left-wrapper-migration.html
new file mode 100644
index 0000000000000..4318a324a9475
--- /dev/null
+++ b/docs/_includes/nav-left-wrapper-migration.html
@@ -0,0 +1,6 @@
+
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 8ea15dc71d541..d5fb18bfb06c0 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -112,6 +112,7 @@
Job Scheduling
Security
Hardware Provisioning
+ Migration Guide
Building Spark
Contributing to Spark
@@ -126,8 +127,10 @@
- {% if page.url contains "/ml" or page.url contains "/sql" %}
- {% if page.url contains "/ml" %}
+ {% if page.url contains "/ml" or page.url contains "/sql" or page.url contains "migration-guide.html" %}
+ {% if page.url contains "migration-guide.html" %}
+ {% include nav-left-wrapper-migration.html nav-migration=site.data.menu-migration %}
+ {% elsif page.url contains "/ml" %}
{% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %}
{% else %}
{% include nav-left-wrapper-sql.html nav-sql=site.data.menu-sql %}
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 37f898645da68..13f848eff88db 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -34,7 +34,7 @@ Spark requires Scala 2.12; support for Scala 2.11 was removed in Spark 3.0.0.
You'll need to configure Maven to use more memory than usual by setting `MAVEN_OPTS`:
- export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m"
+ export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=1g"
(The `ReservedCodeCacheSize` setting is optional but recommended.)
If you don't add these parameters to `MAVEN_OPTS`, you may see errors and warnings like the following:
@@ -82,7 +82,7 @@ Example:
## Building With Hive and JDBC Support
To enable Hive integration for Spark SQL along with its JDBC server and CLI,
-add the `-Phive` and `Phive-thriftserver` profiles to your existing build options.
+add the `-Phive` and `-Phive-thriftserver` profiles to your existing build options.
By default, Spark will use Hive 1.2.1 with the `hadoop-2.7` profile, and Hive 2.3.6 with the `hadoop-3.2` profile.
# With Hive 1.2.1 support
@@ -160,7 +160,7 @@ prompt.
Configure the JVM options for SBT in `.jvmopts` at the project root, for example:
-Xmx2g
- -XX:ReservedCodeCacheSize=512m
+ -XX:ReservedCodeCacheSize=1g
For the meanings of these two options, please carefully read the [Setting up Maven's Memory Usage section](https://spark.apache.org/docs/latest/building-spark.html#setting-up-mavens-memory-usage).
diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md
index a8d40fe7456e4..b2a3e77f1ee9d 100644
--- a/docs/cloud-integration.md
+++ b/docs/cloud-integration.md
@@ -257,4 +257,5 @@ Here is the documentation on the standard connectors both from Apache and the cl
* [Amazon EMR File System (EMRFS)](https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-fs.html). From Amazon
* [Google Cloud Storage Connector for Spark and Hadoop](https://cloud.google.com/hadoop/google-cloud-storage-connector). From Google
* [The Azure Blob Filesystem driver (ABFS)](https://docs.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-abfs-driver)
+* IBM Cloud Object Storage connector for Apache Spark: [Stocator](https://github.com/CODAIT/stocator), [IBM Object Storage](https://www.ibm.com/cloud/object-storage), [how-to-use-connector](https://developer.ibm.com/code/2018/08/16/installing-running-stocator-apache-spark-ibm-cloud-object-storage). From IBM
diff --git a/docs/configuration.md b/docs/configuration.md
index 9933283cdad87..729b1ba7ed2ca 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -230,7 +230,7 @@ of the most common options to set are:
write to STDOUT a JSON string in the format of the ResourceInformation class. This has a
name and an array of addresses. For a client-submitted driver in Standalone, discovery
script must assign different resource addresses to this driver comparing to workers' and
- other dirvers' when spark.resources.coordinate.enable is off.
+ other drivers' when spark.resources.coordinate.enable is off.
@@ -411,6 +411,16 @@ of the most common options to set are:
use the default layout.
+
+ spark.driver.log.allowErasureCoding
+ false
+
+ Whether to allow driver logs to use erasure coding. On HDFS, erasure coded files will not
+ update as quickly as regular replicated files, so they may take longer to reflect changes
+ written by the application. Note that even if this is true, Spark will still not force the
+ file to use erasure coding; it will simply use file system defaults.
+
+
Apart from these, the following properties are also available, and may be useful in some situations:
@@ -866,7 +876,7 @@ Apart from these, the following properties are also available, and may be useful
spark.shuffle.service.index.cache.size
100m
- Cache entries limited to the specified memory footprint in bytes.
+ Cache entries limited to the specified memory footprint, in bytes unless otherwise specified.
@@ -1207,16 +1217,18 @@ Apart from these, the following properties are also available, and may be useful
spark.io.compression.lz4.blockSize
32k
- Block size in bytes used in LZ4 compression, in the case when LZ4 compression codec
+ Block size used in LZ4 compression, in the case when LZ4 compression codec
is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used.
+ Default unit is bytes, unless otherwise specified.
spark.io.compression.snappy.blockSize
32k
- Block size in bytes used in Snappy compression, in the case when Snappy compression codec
- is used. Lowering this block size will also lower shuffle memory usage when Snappy is used.
+ Block size in Snappy compression, in the case when Snappy compression codec is used.
+ Lowering this block size will also lower shuffle memory usage when Snappy is used.
+ Default unit is bytes, unless otherwise specified.
@@ -1384,7 +1396,7 @@ Apart from these, the following properties are also available, and may be useful
spark.memory.offHeap.size
0
- The absolute amount of memory in bytes which can be used for off-heap allocation.
+ The absolute amount of memory which can be used for off-heap allocation, in bytes unless otherwise specified.
This setting has no impact on heap memory usage, so if your executors' total memory consumption
must fit within some hard limit then be sure to shrink your JVM heap size accordingly.
This must be set to a positive value when spark.memory.offHeap.enabled=true.
@@ -1568,9 +1580,9 @@ Apart from these, the following properties are also available, and may be useful
spark.storage.memoryMapThreshold
2m
- Size in bytes of a block above which Spark memory maps when reading a block from disk.
- This prevents Spark from memory mapping very small blocks. In general, memory
- mapping has high overhead for blocks close to or below the page size of the operating system.
+ Size of a block above which Spark memory maps when reading a block from disk. Default unit is bytes,
+ unless specified otherwise. This prevents Spark from memory mapping very small blocks. In general,
+ memory mapping has high overhead for blocks close to or below the page size of the operating system.
@@ -2596,11 +2608,14 @@ You can copy and modify `hdfs-site.xml`, `core-site.xml`, `yarn-site.xml`, `hive
Spark's classpath for each application. In a Spark cluster running on YARN, these configuration
files are set cluster-wide, and cannot safely be changed by the application.
-The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`.
+The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`, and use
+spark hive properties in the form of `spark.hive.*`.
+For example, adding configuration "spark.hadoop.abc.def=xyz" represents adding hadoop property "abc.def=xyz",
+and adding configuration "spark.hive.abc=xyz" represents adding hive property "hive.abc=xyz".
They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defaults.conf`
In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For
-instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties.
+instance, Spark allows you to simply create an empty conf and set spark/spark hadoop/spark hive properties.
{% highlight scala %}
val conf = new SparkConf().set("spark.hadoop.abc.def", "xyz")
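// A hedged editor's sketch: the spark.hive.* form described above works the same
// way; "hive.abc" is the placeholder key used in the surrounding text.
val hiveConf = new SparkConf().set("spark.hive.abc", "xyz")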
@@ -2614,6 +2629,19 @@ Also, you can modify or add configurations at runtime:
--master local[4] \
--conf spark.eventLog.enabled=false \
--conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" \
- --conf spark.hadoop.abc.def=xyz \
+ --conf spark.hadoop.abc.def=xyz \
+ --conf spark.hive.abc=xyz
myApp.jar
{% endhighlight %}
+
+# Custom Resource Scheduling and Configuration Overview
+
+GPUs and other accelerators have been widely used for accelerating special workloads, e.g.,
+deep learning and signal processing. Spark now supports requesting and scheduling generic resources, such as GPUs, with a few caveats. The current implementation requires that the resource have addresses that can be allocated by the scheduler. It requires your cluster manager to support and be properly configured with the resources.
+
+There are configurations available to request resources for the driver: `spark.driver.resource.{resourceName}.amount`, request resources for the executor(s): `spark.executor.resource.{resourceName}.amount`, and specify the requirements for each task: `spark.task.resource.{resourceName}.amount`. The `spark.driver.resource.{resourceName}.discoveryScript` config is required on YARN, Kubernetes and a client-side Driver on Spark Standalone. The `spark.executor.resource.{resourceName}.discoveryScript` config is required for YARN and Kubernetes. Kubernetes also requires `spark.driver.resource.{resourceName}.vendor` and/or `spark.executor.resource.{resourceName}.vendor`. See the config descriptions above for more information on each.
+
+Spark will use the configurations specified to first request containers with the corresponding resources from the cluster manager. Once it gets the container, Spark launches an Executor in that container which will discover what resources the container has and the addresses associated with each resource. The Executor will register with the Driver and report back the resources available to that Executor. The Spark scheduler can then schedule tasks to each Executor and assign specific resource addresses based on the resource requirements the user specified. The user can see the resources assigned to a task using the `TaskContext.get().resources` API. On the driver, the user can see the resources assigned with the `SparkContext` `resources` call. It's then up to the user to use the assigned addresses to do the processing they want or pass those into the ML/AI framework they are using.
+
+See your cluster manager specific page for requirements and details on each of [YARN](running-on-yarn.html#resource-allocation-and-configuration-overview), [Kubernetes](running-on-kubernetes.html#resource-allocation-and-configuration-overview) and [Standalone Mode](spark-standalone.html#resource-allocation-and-configuration-overview). It is currently not available with Mesos or local mode. If using local-cluster mode, see the Spark Standalone documentation, but be aware that only a single worker resources file or discovery script can be specified that is shared by all the Workers, so you should enable resource coordination (see `spark.resources.coordinate.enable`).
+
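As a hedged illustration of the resource configs and the `TaskContext.get().resources` call described above, the sketch below requests one GPU per executor and per task and reads the assigned addresses inside a task. The application name and discovery script path are placeholders, not values taken from this change.

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Minimal sketch, assuming a cluster manager that supports GPU scheduling and a
// discovery script available at the placeholder path below.
val conf = new SparkConf()
  .setAppName("resource-scheduling-sketch")
  .set("spark.executor.resource.gpu.amount", "1")
  .set("spark.executor.resource.gpu.discoveryScript", "/opt/spark/scripts/getGpus.sh")
  .set("spark.task.resource.gpu.amount", "1")
val sc = new SparkContext(conf)

val gpuAddresses = sc.parallelize(1 to 4, 4).map { _ =>
  // resources() returns Map[String, ResourceInformation]; the scheduler assigned
  // these addresses based on the per-task requirement above.
  TaskContext.get().resources()("gpu").addresses.mkString(",")
}.collect()
```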
diff --git a/docs/core-migration-guide.md b/docs/core-migration-guide.md
new file mode 100644
index 0000000000000..2d4d91dab075e
--- /dev/null
+++ b/docs/core-migration-guide.md
@@ -0,0 +1,33 @@
+---
+layout: global
+title: "Migration Guide: Spark Core"
+displayTitle: "Migration Guide: Spark Core"
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+* Table of contents
+{:toc}
+
+## Upgrading from Core 2.4 to 3.0
+
+- In Spark 3.0, the deprecated method `TaskContext.isRunningLocally` has been removed. Local execution support was removed earlier, and the method always returned `false`.
+
+- In Spark 3.0, the deprecated methods `shuffleBytesWritten`, `shuffleWriteTime` and `shuffleRecordsWritten` in `ShuffleWriteMetrics` have been removed. Instead, use `bytesWritten`, `writeTime` and `recordsWritten` respectively (see the sketch after this list).
+
+- In Spark 3.0, the deprecated method `AccumulableInfo.apply` has been removed because creating `AccumulableInfo` is disallowed.
+
+- In Spark 3.0, event log files will be written using UTF-8 encoding, and the Spark History Server will replay event log files as UTF-8. Previously, Spark wrote event log files using the default charset of the driver JVM process, so the Spark History Server of Spark 2.x is needed to read old event log files written with an incompatible encoding.
\ No newline at end of file
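A small hedged sketch of the `ShuffleWriteMetrics` rename noted above: a listener reading shuffle write metrics in Spark 3.0 uses the new accessor names. `ShuffleWriteLogger` is an illustrative class name, not part of this change.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

class ShuffleWriteLogger extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    val m = taskEnd.taskMetrics.shuffleWriteMetrics
    // Replaces the removed shuffleBytesWritten / shuffleWriteTime / shuffleRecordsWritten.
    println(s"shuffle wrote ${m.bytesWritten} bytes, ${m.recordsWritten} records in ${m.writeTime} ns")
  }
}
```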
diff --git a/docs/img/JDBCServer1.png b/docs/img/JDBCServer1.png
new file mode 100644
index 0000000000000..c568b199353ae
Binary files /dev/null and b/docs/img/JDBCServer1.png differ
diff --git a/docs/img/JDBCServer2.png b/docs/img/JDBCServer2.png
new file mode 100644
index 0000000000000..84008c78ef269
Binary files /dev/null and b/docs/img/JDBCServer2.png differ
diff --git a/docs/img/JDBCServer3.png b/docs/img/JDBCServer3.png
new file mode 100644
index 0000000000000..772c3cfdeb967
Binary files /dev/null and b/docs/img/JDBCServer3.png differ
diff --git a/docs/index.md b/docs/index.md
index 4217918a87462..edb1c421fb794 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -46,7 +46,7 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy
locally on one machine --- all you need is to have `java` installed on your system `PATH`,
or the `JAVA_HOME` environment variable pointing to a Java installation.
-Spark runs on Java 8, Scala 2.12, Python 2.7+/3.4+ and R 3.1+.
+Spark runs on Java 8/11, Scala 2.12, Python 2.7+/3.4+ and R 3.1+.
Python 2 support is deprecated as of Spark 3.0.0.
R prior to version 3.4 support is deprecated as of Spark 3.0.0.
For the Scala API, Spark {{site.SPARK_VERSION}}
@@ -146,6 +146,7 @@ options for deployment:
* Integration with other storage systems:
* [Cloud Infrastructures](cloud-integration.html)
* [OpenStack Swift](storage-openstack-swift.html)
+* [Migration Guide](migration-guide.html): Migration guides for Spark components
* [Building Spark](building-spark.html): build Spark using the Maven system
* [Contributing to Spark](https://spark.apache.org/contributing.html)
* [Third Party Projects](https://spark.apache.org/third-party-projects.html): related third party Spark projects
diff --git a/docs/migration-guide.md b/docs/migration-guide.md
new file mode 100644
index 0000000000000..9ca0ada37a2fe
--- /dev/null
+++ b/docs/migration-guide.md
@@ -0,0 +1,30 @@
+---
+layout: global
+title: Migration Guide
+displayTitle: Migration Guide
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+This page documents sections of the migration guide for each component in order
+for users to migrate effectively.
+
+* [Spark Core](core-migration-guide.html)
+* [SQL, Datasets, and DataFrame](sql-migration-guide.html)
+* [Structured Streaming](ss-migration-guide.html)
+* [MLlib (Machine Learning)](ml-migration-guide.html)
+* [PySpark (Python on Spark)](pyspark-migration-guide.html)
+* [SparkR (R on Spark)](sparkr-migration-guide.html)
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 4661d6cd87c04..7b4fa4f651e64 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -113,68 +113,7 @@ transforming multiple columns.
* Robust linear regression with Huber loss
([SPARK-3181](https://issues.apache.org/jira/browse/SPARK-3181)).
-# Migration guide
+# Migration Guide
-MLlib is under active development.
-The APIs marked `Experimental`/`DeveloperApi` may change in future releases,
-and the migration guide below will explain all changes between releases.
+The migration guide is now archived [on this page](ml-migration-guide.html).
-## From 2.4 to 3.0
-
-### Breaking changes
-
-* `OneHotEncoder` which is deprecated in 2.3, is removed in 3.0 and `OneHotEncoderEstimator` is now renamed to `OneHotEncoder`.
-
-### Changes of behavior
-
-* [SPARK-11215](https://issues.apache.org/jira/browse/SPARK-11215):
- In Spark 2.4 and previous versions, when specifying `frequencyDesc` or `frequencyAsc` as
- `stringOrderType` param in `StringIndexer`, in case of equal frequency, the order of
- strings is undefined. Since Spark 3.0, the strings with equal frequency are further
- sorted by alphabet. And since Spark 3.0, `StringIndexer` supports encoding multiple
- columns.
-
-## From 2.2 to 2.3
-
-### Breaking changes
-
-* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner
-and better accommodate the addition of the multi-class summary. This is a breaking change for user
-code that casts a `LogisticRegressionTrainingSummary` to a
-`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary`
-method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail
-(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which
-will still work correctly for both multinomial and binary cases.
-
-### Deprecations and changes of behavior
-
-**Deprecations**
-
-* `OneHotEncoder` has been deprecated and will be removed in `3.0`. It has been replaced by the
-new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator)
-(see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030)). **Note** that
-`OneHotEncoderEstimator` will be renamed to `OneHotEncoder` in `3.0` (but
-`OneHotEncoderEstimator` will be kept as an alias).
-
-**Changes of behavior**
-
-* [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027):
- The default parallelism used in `OneVsRest` is now set to 1 (i.e. serial). In `2.2` and
- earlier versions, the level of parallelism was set to the default threadpool size in Scala.
-* [SPARK-22156](https://issues.apache.org/jira/browse/SPARK-22156):
- The learning rate update for `Word2Vec` was incorrect when `numIterations` was set greater than
- `1`. This will cause training results to be different between `2.3` and earlier versions.
-* [SPARK-21681](https://issues.apache.org/jira/browse/SPARK-21681):
- Fixed an edge case bug in multinomial logistic regression that resulted in incorrect coefficients
- when some features had zero variance.
-* [SPARK-16957](https://issues.apache.org/jira/browse/SPARK-16957):
- Tree algorithms now use mid-points for split values. This may change results from model training.
-* [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657):
- Fixed an issue where the features generated by `RFormula` without an intercept were inconsistent
- with the output in R. This may change results from model training in this scenario.
-
-## Previous Spark versions
-
-Earlier migration guides are archived [on this page](ml-migration-guides.html).
-
----
diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guide.md
similarity index 85%
rename from docs/ml-migration-guides.md
rename to docs/ml-migration-guide.md
index 99edd9bd69efa..9e8cd3e07b1ee 100644
--- a/docs/ml-migration-guides.md
+++ b/docs/ml-migration-guide.md
@@ -1,8 +1,7 @@
---
layout: global
-title: Old Migration Guides - MLlib
-displayTitle: Old Migration Guides - MLlib
-description: MLlib migration guides from before Spark SPARK_VERSION_SHORT
+title: "Migration Guide: MLlib (Machine Learning)"
+displayTitle: "Migration Guide: MLlib (Machine Learning)"
license: |
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
@@ -20,15 +19,80 @@ license: |
limitations under the License.
---
-The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide).
+* Table of contents
+{:toc}
-## From 2.1 to 2.2
+Note that this migration guide describes the items specific to MLlib.
+Many items of the SQL migration guide also apply when migrating MLlib to higher versions for the DataFrame-based APIs.
+Please refer to [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide.html).
+
+## Upgrading from MLlib 2.4 to 3.0
+
+### Breaking changes
+{:.no_toc}
+
+* `OneHotEncoder`, which was deprecated in 2.3, is removed in 3.0, and `OneHotEncoderEstimator` is now renamed to `OneHotEncoder`.
+
+### Changes of behavior
+{:.no_toc}
+
+* [SPARK-11215](https://issues.apache.org/jira/browse/SPARK-11215):
+ In Spark 2.4 and previous versions, when specifying `frequencyDesc` or `frequencyAsc` as the
+ `stringOrderType` param in `StringIndexer`, the order of strings with equal frequency is
+ undefined. Since Spark 3.0, strings with equal frequency are further sorted alphabetically.
+ Also since Spark 3.0, `StringIndexer` supports encoding multiple columns (see the sketch
+ after this list).
+
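+A minimal sketch of the multi-column usage (assuming an active `spark` session; the column names below are hypothetical):
+
+```python
+from pyspark.ml.feature import StringIndexer
+
+df = spark.createDataFrame(
+    [("a", "x"), ("b", "y"), ("a", "y")], ["col1", "col2"])
+
+# Since Spark 3.0, a single StringIndexer can index multiple columns at once.
+indexer = StringIndexer(inputCols=["col1", "col2"],
+                        outputCols=["col1_idx", "col2_idx"])
+indexer.fit(df).transform(df).show()
+```
+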
+## Upgrading from MLlib 2.2 to 2.3
+
+### Breaking changes
+{:.no_toc}
+
+* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner
+and better accommodate the addition of the multi-class summary. This is a breaking change for user
+code that casts a `LogisticRegressionTrainingSummary` to a
+`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary`
+method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail
+(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which
+will still work correctly for both multinomial and binary cases.
+
+### Deprecations and changes of behavior
+{:.no_toc}
+
+**Deprecations**
+
+* `OneHotEncoder` has been deprecated and will be removed in `3.0`. It has been replaced by the
+new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator)
+(see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030)). **Note** that
+`OneHotEncoderEstimator` will be renamed to `OneHotEncoder` in `3.0` (but
+`OneHotEncoderEstimator` will be kept as an alias).
+
+**Changes of behavior**
+
+* [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027):
+ The default parallelism used in `OneVsRest` is now set to 1 (i.e. serial). In `2.2` and
+ earlier versions, the level of parallelism was set to the default threadpool size in Scala.
+* [SPARK-22156](https://issues.apache.org/jira/browse/SPARK-22156):
+ The learning rate update for `Word2Vec` was incorrect when `numIterations` was set greater than
+ `1`. This will cause training results to be different between `2.3` and earlier versions.
+* [SPARK-21681](https://issues.apache.org/jira/browse/SPARK-21681):
+ Fixed an edge case bug in multinomial logistic regression that resulted in incorrect coefficients
+ when some features had zero variance.
+* [SPARK-16957](https://issues.apache.org/jira/browse/SPARK-16957):
+ Tree algorithms now use mid-points for split values. This may change results from model training.
+* [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657):
+ Fixed an issue where the features generated by `RFormula` without an intercept were inconsistent
+ with the output in R. This may change results from model training in this scenario.
+
+## Upgrading from MLlib 2.1 to 2.2
### Breaking changes
+{:.no_toc}
There are no breaking changes.
### Deprecations and changes of behavior
+{:.no_toc}
**Deprecations**
@@ -45,9 +109,10 @@ There are no deprecations.
`StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception
would always be thrown regardless of the setting of the `handleInvalid` parameter.
-## From 2.0 to 2.1
+## Upgrading from MLlib 2.0 to 2.1
### Breaking changes
+{:.no_toc}
**Deprecated methods removed**
@@ -59,6 +124,7 @@ There are no deprecations.
* `validateParams` in `Evaluator`
### Deprecations and changes of behavior
+{:.no_toc}
**Deprecations**
@@ -74,9 +140,10 @@ There are no deprecations.
* [SPARK-17389](https://issues.apache.org/jira/browse/SPARK-17389):
`KMeans` reduces the default number of steps from 5 to 2 for the k-means|| initialization mode.
-## From 1.6 to 2.0
+## Upgrading from MLlib 1.6 to 2.0
### Breaking changes
+{:.no_toc}
There were several breaking changes in Spark 2.0, which are outlined below.
@@ -171,6 +238,7 @@ Several deprecated methods were removed in the `spark.mllib` and `spark.ml` pack
A full list of breaking changes can be found at [SPARK-14810](https://issues.apache.org/jira/browse/SPARK-14810).
### Deprecations and changes of behavior
+{:.no_toc}
**Deprecations**
@@ -221,7 +289,7 @@ Changes of behavior in the `spark.mllib` and `spark.ml` packages include:
`QuantileDiscretizer` now uses `spark.sql.DataFrameStatFunctions.approxQuantile` to find splits (previously used custom sampling logic).
The output buckets will differ for same input data and params.
-## From 1.5 to 1.6
+## Upgrading from MLlib 1.5 to 1.6
There are no breaking API changes in the `spark.mllib` or `spark.ml` packages, but there are
deprecations and changes of behavior.
@@ -248,7 +316,7 @@ Changes of behavior:
tokenizing. Now, it converts to lowercase by default, with an option not to. This matches the
behavior of the simpler `Tokenizer` transformer.
-## From 1.4 to 1.5
+## Upgrading from MLlib 1.4 to 1.5
In the `spark.mllib` package, there are no breaking API changes but several behavior changes:
@@ -267,7 +335,7 @@ In the `spark.ml` package, there exists one breaking API change and one behavior
* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is
added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4.
-## From 1.3 to 1.4
+## Upgrading from MLlib 1.3 to 1.4
In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs:
@@ -286,7 +354,7 @@ Since the `spark.ml` API was an alpha component in Spark 1.3, we do not list all
However, since 1.4 `spark.ml` is no longer an alpha component, we will provide details on any API
changes for future releases.
-## From 1.2 to 1.3
+## Upgrading from MLlib 1.2 to 1.3
In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental.
@@ -313,7 +381,7 @@ Other changes were in `LogisticRegression`:
* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future).
* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future.
-## From 1.1 to 1.2
+## Upgrading from MLlib 1.1 to 1.2
The only API changes in MLlib v1.2 are in
[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
@@ -339,7 +407,7 @@ The tree `Node` now includes more information, including the probability of the
Examples in the Spark distribution and examples in the
[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly.
-## From 1.0 to 1.1
+## Upgrading from MLlib 1.0 to 1.1
The only API changes in MLlib v1.1 are in
[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
@@ -365,7 +433,7 @@ simple `String` types.
Examples of the new recommended `trainClassifier` and `trainRegressor` are given in the
[Decision Trees Guide](mllib-decision-tree.html#examples).
-## From 0.9 to 1.0
+## Upgrading from MLlib 0.9 to 1.0
In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few
breaking changes. If your data is sparse, please store it in a sparse format instead of dense to
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
index f931fa32ea541..fb2883de6810a 100644
--- a/docs/mllib-evaluation-metrics.md
+++ b/docs/mllib-evaluation-metrics.md
@@ -577,31 +577,3 @@ variable from a number of independent variables.
-
-**Examples**
-
-
-The following code snippets illustrate how to load a sample dataset, train a linear regression algorithm on the data,
-and evaluate the performance of the algorithm by several regression metrics.
-
-
-Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) for details on the API.
-
-{% include_example scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala %}
-
-
-
-
-Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) for details on the API.
-
-{% include_example java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java %}
-
-
-
-
-Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) for more details on the API.
-
-{% include_example python/mllib/regression_metrics_example.py %}
-
-
-
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index b7f8ae9d07b0a..33a223ad486af 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -348,17 +348,3 @@ Refer to the [`ElementwiseProduct` Python docs](api/python/pyspark.mllib.html#py
A feature transformer that projects vectors to a low-dimensional space using PCA.
Details you can read at [dimensionality reduction](mllib-dimensionality-reduction.html).
-
-### Example
-
-The following code demonstrates how to compute principal components on a `Vector`
-and use them to project the vectors into a low-dimensional space while keeping associated labels
-for calculation a [Linear Regression](mllib-linear-methods.html)
-
-
-
-Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.PCA) for details on the API.
-
-{% include_example scala/org/apache/spark/examples/mllib/PCAExample.scala %}
-
-
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index 2d3ec4ca24443..801876dbffa79 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -360,57 +360,6 @@ regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) u
regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is
known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error).
-**Examples**
-
-
-
-
-The following example demonstrates how to load training data, parse it as an RDD of LabeledPoint.
-The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
-values. We compute the mean squared error at the end to evaluate
-[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
-
-Refer to the [`LinearRegressionWithSGD` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionModel) for details on the API.
-
-{% include_example scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala %}
-
-[`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD)
-and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) can be used in a similar fashion as `LinearRegressionWithSGD`.
-
-
-
-
-All of MLlib's methods use Java-friendly types, so you can import and call them there the same
-way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the
-Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
-calling `.rdd()` on your `JavaRDD` object. The corresponding Java example to
-the Scala snippet provided, is presented below:
-
-Refer to the [`LinearRegressionWithSGD` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionWithSGD.html) and [`LinearRegressionModel` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionModel.html) for details on the API.
-
-{% include_example java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java %}
-
-
-
-The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint.
-The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
-values. We compute the mean squared error at the end to evaluate
-[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
-
-Note that the Python API does not yet support model save/load but will in the future.
-
-Refer to the [`LinearRegressionWithSGD` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionModel) for more details on the API.
-
-{% include_example python/mllib/linear_regression_with_sgd_example.py %}
-
-
-
-In order to run the above application, follow the instructions
-provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
-section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
-
### Streaming linear regression
When data arrive in a streaming fashion, it is useful to fit regression models online,
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 8c81916d4f7d0..a45a41dc78cc3 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -1059,6 +1059,11 @@ when running in local mode.
- hiveClientCalls.count
- sourceCodeSize (histogram)
+- namespace=
+ - Optional namespace(s). Metrics in this namespace are defined by user-supplied code, and
+ configured using the Spark executor plugin infrastructure.
+ See also the configuration parameter `spark.executor.plugins`
+
### Source = JVM Source
Notes:
- Activate this source by setting the relevant `metrics.properties` file entry or the
diff --git a/docs/pyspark-migration-guide.md b/docs/pyspark-migration-guide.md
new file mode 100644
index 0000000000000..889941c37bf43
--- /dev/null
+++ b/docs/pyspark-migration-guide.md
@@ -0,0 +1,120 @@
+---
+layout: global
+title: "Migration Guide: PySpark (Python on Spark)"
+displayTitle: "Migration Guide: PySpark (Python on Spark)"
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+* Table of contents
+{:toc}
+
+Note that this migration guide describes the items specific to PySpark.
+Many items of the SQL migration guide also apply when migrating PySpark to higher versions.
+Please refer to [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide.html).
+
+## Upgrading from PySpark 2.4 to 3.0
+
+ - Since Spark 3.0, PySpark requires a Pandas version of 0.23.2 or higher to use Pandas related functionality, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc.
+
+ - Since Spark 3.0, PySpark requires a PyArrow version of 0.12.1 or higher to use PyArrow related functionality, such as `pandas_udf`, `toPandas` and `createDataFrame` with "spark.sql.execution.arrow.enabled=true", etc.
+
+ - In PySpark, when creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder tried to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder; but the `SparkContext` is shared by all `SparkSession`s, so it should not be updated. Since 3.0, the builder no longer updates the configurations. This is the same behavior as the Java/Scala API in 2.3 and above. If you want to update them, you need to do so prior to creating a `SparkSession`.
+
+ - In PySpark, when Arrow optimization is enabled, if Arrow version is higher than 0.11.0, Arrow can perform safe type conversion when converting Pandas.Series to Arrow array during serialization. Arrow will raise errors when detecting unsafe type conversion like overflow. Setting `spark.sql.execution.pandas.arrowSafeTypeConversion` to true can enable it. The default setting is false. PySpark's behavior for Arrow versions is illustrated in the table below:
+
+    <table class="table">
+      <tr>
+        <th>PyArrow version</th>
+        <th>Integer Overflow</th>
+        <th>Floating Point Truncation</th>
+      </tr>
+      <tr>
+        <td>version < 0.11.0</td>
+        <td>Raise error</td>
+        <td>Silently allows</td>
+      </tr>
+      <tr>
+        <td>version > 0.11.0, arrowSafeTypeConversion=false</td>
+        <td>Silent overflow</td>
+        <td>Silently allows</td>
+      </tr>
+      <tr>
+        <td>version > 0.11.0, arrowSafeTypeConversion=true</td>
+        <td>Raise error</td>
+        <td>Raise error</td>
+      </tr>
+    </table>
+
+ - Since Spark 3.0, `createDataFrame(..., verifySchema=True)` validates `LongType` values as well in PySpark. Previously, `LongType` was not verified and resulted in `None` when the value overflowed. To restore the old behavior, `verifySchema` can be set to `False` to disable the validation (see the sketch after this list).
+
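+A minimal sketch of the `verifySchema` behavior described above (assuming an active `spark` session; the schema and value are hypothetical):
+
+```python
+from pyspark.sql.types import LongType, StructField, StructType
+
+schema = StructType([StructField("v", LongType())])
+too_big = 2 ** 63  # does not fit into a signed 64-bit LongType
+
+# Since Spark 3.0 this raises an error because the value overflows LongType:
+# spark.createDataFrame([(too_big,)], schema=schema).collect()
+
+# Disabling the validation restores the old behavior and yields None.
+spark.createDataFrame([(too_big,)], schema=schema, verifySchema=False).collect()
+```
+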
+## Upgrading from PySpark 2.3 to 2.4
+
+ - In PySpark, when Arrow optimization is enabled, `toPandas` previously just failed when Arrow optimization could not be used, whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`.
+
+## Upgrading from PySpark 2.3.0 to 2.3.1 and above
+
+ - As of version 2.3.1 Arrow functionality, including `pandas_udf` and `toPandas()`/`createDataFrame()` with `spark.sql.execution.arrow.enabled` set to `True`, has been marked as experimental. These are still evolving and not currently recommended for use in production.
+
+## Upgrading from PySpark 2.2 to 2.3
+
+ - In PySpark, Pandas 0.19.2 or higher is now required if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc.
+
+ - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details.
+
+ - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame.
+
+ - In PySpark, `df.replace` does not allow omitting `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and defaulted to `None`, which was counterintuitive and error-prone (see the sketch after this list).
+
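+A minimal sketch of the stricter `df.replace` behavior (assuming an active `spark` session; the data and column name are hypothetical):
+
+```python
+df = spark.createDataFrame([("a",), ("b",)], ["col"])
+
+# Dictionary form: `value` may still be omitted.
+df.replace({"a": "x"}).show()
+
+# Non-dictionary form: `value` must now be given explicitly.
+df.replace("a", "x").show()
+
+# df.replace("a")  # raises an error since Spark 2.3 because `value` is omitted
+```
+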
+## Upgrading from PySpark 1.4 to 1.5
+
+ - Resolution of strings to columns in Python now supports using dots (`.`) to qualify the column or
+ access nested values. For example `df['table.column.nestedField']`. However, this means that if
+ your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``).
+
+ - DataFrame.withColumn method in PySpark supports adding a new column or replacing existing columns of the same name.
+
+
+## Upgrading from PySpark 1.0-1.2 to 1.3
+
+#### Python DataTypes No Longer Singletons
+{:.no_toc}
+
+When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of
+referencing a singleton.
diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md
index 2d4e5cd65f497..4ef738ed9ef6e 100644
--- a/docs/running-on-kubernetes.md
+++ b/docs/running-on-kubernetes.md
@@ -1266,3 +1266,14 @@ The following affect the driver and executor containers. All other containers in
+
+### Resource Allocation and Configuration Overview
+
+Please make sure to have read the Custom Resource Scheduling and Configuration Overview section on the [configuration page](configuration.html). This section only covers the Kubernetes-specific aspects of resource scheduling.
+
+The user is responsible for properly configuring the Kubernetes cluster to have the resources available and, ideally, for isolating each resource per container so that a resource is not shared between multiple containers. If the resource is not isolated, the user is responsible for writing a discovery script so that the resource is not shared between containers. See the Kubernetes documentation for specifics on configuring Kubernetes with [custom resources](https://kubernetes.io/docs/concepts/extend-kubernetes/compute-storage-net/device-plugins/).
+
+Spark automatically handles translating the Spark configs `spark.{driver/executor}.resource.{resourceType}` into the Kubernetes configs as long as the Kubernetes resource type follows the Kubernetes device plugin format of `vendor-domain/resourcetype`. The user must specify the vendor using the `spark.{driver/executor}.resource.{resourceType}.vendor` config. The user does not need to explicitly add anything if using Pod templates. For reference and an example, see the Kubernetes documentation for scheduling [GPUs](https://kubernetes.io/docs/tasks/manage-gpus/scheduling-gpus/). Spark only supports setting the resource limits.
+
+Kubernetes does not tell Spark the addresses of the resources allocated to each container. For that reason, the user must specify a discovery script that gets run by the executor on startup to discover what resources are available to that executor. You can find an example script in `examples/src/main/scripts/getGpusResources.sh`. The script must have execute permissions set, and the user should set up permissions so that malicious users cannot modify it. The script should write to STDOUT a JSON string in the format of the ResourceInformation class. This has the resource name and an array of resource addresses available to just that executor.
+
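+For example, a GPU request on Kubernetes might look like the following minimal sketch (the vendor domain `nvidia.com` and the script path are illustrative assumptions, not required values):
+
+```python
+from pyspark.sql import SparkSession
+
+spark = (SparkSession.builder
+         .appName("k8s-gpu-example")
+         # Translated into a Kubernetes resource limit of the form nvidia.com/gpu.
+         .config("spark.executor.resource.gpu.amount", "1")
+         .config("spark.executor.resource.gpu.vendor", "nvidia.com")
+         .config("spark.executor.resource.gpu.discoveryScript",
+                 "/opt/spark/examples/src/main/scripts/getGpusResources.sh")
+         .getOrCreate())
+```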
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index d3d049e6fef70..418db41216cdb 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -542,6 +542,20 @@ For example, suppose you would like to point log url link to Job History Server
NOTE: you need to replace `` and `` with actual value.
+# Resource Allocation and Configuration Overview
+
+Please make sure to have read the Custom Resource Scheduling and Configuration Overview section on the [configuration page](configuration.html). This section only covers the YARN-specific aspects of resource scheduling.
+
+YARN needs to be configured to support any resources the user wants to use with Spark. Resource scheduling on YARN was added in YARN 3.1.0. See the YARN documentation for more information on configuring resources and properly setting up isolation. Ideally, resources are set up to be isolated so that an executor can only see the resources it was allocated. If isolation is not enabled, the user is responsible for creating a discovery script that ensures the resource is not shared between executors.
+
+YARN currently supports any user-defined resource type but has built-in types for GPU (`yarn.io/gpu`) and FPGA (`yarn.io/fpga`). For that reason, if you are using either of those resources, Spark can translate your request for Spark resources into YARN resources and you only have to specify the `spark.{driver/executor}.resource.` configs. If you are using a resource other than FPGA or GPU, the user is responsible for specifying the configs for both YARN (`spark.yarn.{driver/executor}.resource.`) and Spark (`spark.{driver/executor}.resource.`).
+
+For example, suppose the user wants to request 2 GPUs for each executor. The user can just specify `spark.executor.resource.gpu.amount=2` and Spark will handle requesting the `yarn.io/gpu` resource type from YARN, as sketched below.
+
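+A minimal sketch of that request (the discovery script path is an illustrative assumption):
+
+```python
+from pyspark.sql import SparkSession
+
+spark = (SparkSession.builder
+         .appName("yarn-gpu-example")
+         .master("yarn")
+         # Spark translates this into a yarn.io/gpu request for each executor.
+         .config("spark.executor.resource.gpu.amount", "2")
+         .config("spark.executor.resource.gpu.discoveryScript",
+                 "/opt/spark/examples/src/main/scripts/getGpusResources.sh")
+         .getOrCreate())
+```
+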
+If the user has a user-defined YARN resource, let's call it `acceleratorX`, then the user must specify `spark.yarn.executor.resource.acceleratorX.amount=2` and `spark.executor.resource.acceleratorX.amount=2`.
+
+YARN does not tell Spark the addresses of the resources allocated to each container. For that reason, the user must specify a discovery script that gets run by the executor on startup to discover what resources are available to that executor. You can find an example script in `examples/src/main/scripts/getGpusResources.sh`. The script must have execute permissions set, and the user should set up permissions so that malicious users cannot modify it. The script should write to STDOUT a JSON string in the format of the ResourceInformation class. This has the resource name and an array of resource addresses available to just that executor.
+
# Important notes
- Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured.
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index bc77469b6664f..1264951a2f270 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -340,6 +340,18 @@ SPARK_WORKER_OPTS supports the following system properties:
+# Resource Allocation and Configuration Overview
+
+Please make sure to have read the Custom Resource Scheduling and Configuration Overview section on the [configuration page](configuration.html). This section only covers the Spark Standalone-specific aspects of resource scheduling.
+
+Spark Standalone has 2 parts: the first is configuring the resources for the Worker, and the second is the resource allocation for a specific application.
+
+The user must configure the Workers to have a set of resources available so that it can assign them out to Executors. The `spark.worker.resource.{resourceName}.amount` is used to control the amount of each resource the worker has allocated. The user must also specify either `spark.worker.resourcesFile` or `spark.worker.resource.{resourceName}.discoveryScript` to specify how the Worker discovers the resources it is assigned. See the descriptions above for each of those to see which method works best for your setup. Please take note of `spark.resources.coordinate.enable`, as it indicates whether Spark should handle coordinating resources or whether the user has made sure each Worker has separate resources. Also note that if resource coordination is used, `spark.resources.dir` can be used to specify the directory used to do that coordination.
+
+The second part is running an application on Spark Standalone. The only special case from the standard Spark resource configs is when you are running the Driver in client mode. For a Driver in client mode, the user can specify the resources it uses via `spark.driver.resourcesFile` or `spark.driver.resource.{resourceName}.discoveryScript`. If the Driver is running on the same host as other Drivers or Workers, there are 2 ways to make sure they don't use the same resources. The user can either turn `spark.resources.coordinate.enable` on and give all the Drivers/Workers the same set of resources, and Spark will handle making sure each Driver/Worker has separate resources, or the user can make sure the resources file or discovery script only returns resources that do not conflict with other Drivers or Workers running on the same node.
+
+Note that the user does not need to specify a discovery script when submitting an application, as the Worker will start each Executor with the resources it allocates to it.
+
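+A minimal sketch for a client-mode Driver on Standalone (the master URL and the script path are illustrative assumptions):
+
+```python
+from pyspark.sql import SparkSession
+
+spark = (SparkSession.builder
+         .appName("standalone-gpu-example")
+         .master("spark://master-host:7077")
+         # Resources used by the client-mode Driver itself.
+         .config("spark.driver.resource.gpu.amount", "1")
+         .config("spark.driver.resource.gpu.discoveryScript",
+                 "/opt/spark/examples/src/main/scripts/getGpusResources.sh")
+         # Executor resources are assigned by the Worker; no discovery script needed here.
+         .config("spark.executor.resource.gpu.amount", "1")
+         .getOrCreate())
+```
+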
# Connecting an Application to the Cluster
To run an application on the Spark cluster, simply pass the `spark://IP:PORT` URL of the master as to the [`SparkContext`
@@ -420,7 +432,7 @@ In addition, detailed log output for each job is also written to the work direct
# Running Alongside Hadoop
-You can run Spark alongside your existing Hadoop cluster by just launching it as a separate service on the same machines. To access Hadoop data from Spark, just use a hdfs:// URL (typically `hdfs://:9000/path`, but you can find the right URL on your Hadoop Namenode's web UI). Alternatively, you can set up a separate cluster for Spark, and still have it access HDFS over the network; this will be slower than disk-local access, but may not be a concern if you are still running in the same local area network (e.g. you place a few Spark machines on each rack that you have Hadoop on).
+You can run Spark alongside your existing Hadoop cluster by just launching it as a separate service on the same machines. To access Hadoop data from Spark, just use an hdfs:// URL (typically `hdfs://:9000/path`, but you can find the right URL on your Hadoop Namenode's web UI). Alternatively, you can set up a separate cluster for Spark, and still have it access HDFS over the network; this will be slower than disk-local access, but may not be a concern if you are still running in the same local area network (e.g. you place a few Spark machines on each rack that you have Hadoop on).
# Configuring Ports for Network Security
diff --git a/docs/sparkr-migration-guide.md b/docs/sparkr-migration-guide.md
new file mode 100644
index 0000000000000..6fbc4c03aefc1
--- /dev/null
+++ b/docs/sparkr-migration-guide.md
@@ -0,0 +1,77 @@
+---
+layout: global
+title: "Migration Guide: SparkR (R on Spark)"
+displayTitle: "Migration Guide: SparkR (R on Spark)"
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+* Table of contents
+{:toc}
+
+Note that this migration guide describes the items specific to SparkR.
+Many items of the SQL migration guide also apply when migrating SparkR to higher versions.
+Please refer to [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide.html).
+
+## Upgrading from SparkR 2.4 to 3.0
+
+ - The deprecated methods `sparkR.init`, `sparkRSQL.init`, `sparkRHive.init` have been removed. Use `sparkR.session` instead.
+ - The deprecated methods `parquetFile`, `saveAsParquetFile`, `jsonFile`, `registerTempTable`, `createExternalTable`, and `dropTempTable` have been removed. Use `read.parquet`, `write.parquet`, `read.json`, `createOrReplaceTempView`, `createTable`, `dropTempView`, `union` instead.
+
+## Upgrading from SparkR 2.3 to 2.4
+
+ - Previously, we did not check the validity of the size of the last layer in `spark.mlp`. For example, if the training data only has two labels, a `layers` param like `c(1, 3)` did not previously cause an error; now it does.
+
+## Upgrading from SparkR 2.3 to 2.3.1 and above
+
+ - In SparkR 2.3.0 and earlier, the `start` parameter of the `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match the behaviour of `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of the `substr` method is now 1-based. As an example, `substr(lit('abcdef'), 2, 4)` would result in `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1.
+
+## Upgrading from SparkR 2.2 to 2.3
+
+ - The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE)`. It has been corrected.
+ - For `summary`, an option for statistics to compute has been added. Its output is changed from that of `describe`.
+ - A warning can be raised if the versions of the SparkR package and the Spark JVM do not match.
+
+## Upgrading from SparkR 2.1 to 2.2
+
+ - A `numPartitions` parameter has been added to `createDataFrame` and `as.DataFrame`. When splitting the data, the partition position calculation has been made to match the one in Scala.
+ - The method `createExternalTable` has been deprecated to be replaced by `createTable`. Either method can be called to create an external or managed table. Additional catalog methods have also been added.
+ - By default, derby.log is now saved to `tempdir()`. This will be created when instantiating the SparkSession with `enableHiveSupport` set to `TRUE`.
+ - `spark.lda` was not setting the optimizer correctly. It has been corrected.
+ - Several model summary outputs are updated to have `coefficients` as `matrix`. This includes `spark.logit`, `spark.kmeans`, `spark.glm`. Model summary outputs for `spark.gaussianMixture` have added log-likelihood as `loglik`.
+
+## Upgrading from SparkR 2.0 to 2.1
+
+ - `join` no longer performs Cartesian Product by default, use `crossJoin` instead.
+
+
+## Upgrading from SparkR 1.6 to 2.0
+
+ - The method `table` has been removed and replaced by `tableToDF`.
+ - The class `DataFrame` has been renamed to `SparkDataFrame` to avoid name conflicts.
+ - Spark's `SQLContext` and `HiveContext` have been deprecated to be replaced by `SparkSession`. Instead of `sparkR.init()`, call `sparkR.session()` in its place to instantiate the SparkSession. Once that is done, the currently active SparkSession will be used for SparkDataFrame operations.
+ - The parameter `sparkExecutorEnv` is not supported by `sparkR.session`. To set environment for the executors, set Spark config properties with the prefix "spark.executorEnv.VAR_NAME", for example, "spark.executorEnv.PATH"
+ - The `sqlContext` parameter is no longer required for these functions: `createDataFrame`, `as.DataFrame`, `read.json`, `jsonFile`, `read.parquet`, `parquetFile`, `read.text`, `sql`, `tables`, `tableNames`, `cacheTable`, `uncacheTable`, `clearCache`, `dropTempTable`, `read.df`, `loadDF`, `createExternalTable`.
+ - The method `registerTempTable` has been deprecated to be replaced by `createOrReplaceTempView`.
+ - The method `dropTempTable` has been deprecated to be replaced by `dropTempView`.
+ - The `sc` SparkContext parameter is no longer required for these functions: `setJobGroup`, `clearJobGroup`, `cancelJobGroup`
+
+## Upgrading from SparkR 1.5 to 1.6
+
+ - Before Spark 1.6.0, the default mode for writes was `append`. It was changed in Spark 1.6.0 to `error` to match the Scala API.
+ - SparkSQL converts `NA` in R to `null` and vice-versa.
+ - Since 1.6.1, withColumn method in SparkR supports adding a new column to or replacing existing columns
+ of the same name of a DataFrame.
diff --git a/docs/sparkr.md b/docs/sparkr.md
index 7431d025aa629..24fa3b4feac19 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -663,13 +663,20 @@ Apache Arrow is an in-memory columnar data format that is used in Spark to effic
## Ensure Arrow Installed
-Currently, Arrow R library is not on CRAN yet [ARROW-3204](https://issues.apache.org/jira/browse/ARROW-3204). Therefore, it should be installed directly from Github. You can use `remotes::install_github` as below.
+Arrow R library is now available on CRAN ([ARROW-3204](https://issues.apache.org/jira/browse/ARROW-3204)). It can be installed as below.
+
+```bash
+Rscript -e 'install.packages("arrow", repos="https://cloud.r-project.org/")'
+```
+
+If you need to install old versions, it should be installed directly from Github. You can use `remotes::install_github` as below.
```bash
Rscript -e 'remotes::install_github("apache/arrow@apache-arrow-0.12.1", subdir = "r")'
```
-`apache-arrow-0.12.1` is a version tag that can be checked in [Arrow at Github](https://github.com/apache/arrow/releases). You must ensure that Arrow R package is installed and available on all cluster nodes. The current supported version is 0.12.1.
+`apache-arrow-0.12.1` is a version tag that can be checked in [Arrow at Github](https://github.com/apache/arrow/releases). You must ensure that Arrow R package is installed and available on all cluster nodes.
+The current supported minimum version is 0.12.1; however, this might change between the minor releases since Arrow optimization in SparkR is experimental.
## Enabling for Conversion to/from R DataFrame, `dapply` and `gapply`
@@ -748,49 +755,5 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma
# Migration Guide
-## Upgrading From SparkR 1.5.x to 1.6.x
-
- - Before Spark 1.6.0, the default mode for writes was `append`. It was changed in Spark 1.6.0 to `error` to match the Scala API.
- - SparkSQL converts `NA` in R to `null` and vice-versa.
-
-## Upgrading From SparkR 1.6.x to 2.0
-
- - The method `table` has been removed and replaced by `tableToDF`.
- - The class `DataFrame` has been renamed to `SparkDataFrame` to avoid name conflicts.
- - Spark's `SQLContext` and `HiveContext` have been deprecated to be replaced by `SparkSession`. Instead of `sparkR.init()`, call `sparkR.session()` in its place to instantiate the SparkSession. Once that is done, that currently active SparkSession will be used for SparkDataFrame operations.
- - The parameter `sparkExecutorEnv` is not supported by `sparkR.session`. To set environment for the executors, set Spark config properties with the prefix "spark.executorEnv.VAR_NAME", for example, "spark.executorEnv.PATH"
- - The `sqlContext` parameter is no longer required for these functions: `createDataFrame`, `as.DataFrame`, `read.json`, `jsonFile`, `read.parquet`, `parquetFile`, `read.text`, `sql`, `tables`, `tableNames`, `cacheTable`, `uncacheTable`, `clearCache`, `dropTempTable`, `read.df`, `loadDF`, `createExternalTable`.
- - The method `registerTempTable` has been deprecated to be replaced by `createOrReplaceTempView`.
- - The method `dropTempTable` has been deprecated to be replaced by `dropTempView`.
- - The `sc` SparkContext parameter is no longer required for these functions: `setJobGroup`, `clearJobGroup`, `cancelJobGroup`
-
-## Upgrading to SparkR 2.1.0
-
- - `join` no longer performs Cartesian Product by default, use `crossJoin` instead.
-
-## Upgrading to SparkR 2.2.0
-
- - A `numPartitions` parameter has been added to `createDataFrame` and `as.DataFrame`. When splitting the data, the partition position calculation has been made to match the one in Scala.
- - The method `createExternalTable` has been deprecated to be replaced by `createTable`. Either methods can be called to create external or managed table. Additional catalog methods have also been added.
- - By default, derby.log is now saved to `tempdir()`. This will be created when instantiating the SparkSession with `enableHiveSupport` set to `TRUE`.
- - `spark.lda` was not setting the optimizer correctly. It has been corrected.
- - Several model summary outputs are updated to have `coefficients` as `matrix`. This includes `spark.logit`, `spark.kmeans`, `spark.glm`. Model summary outputs for `spark.gaussianMixture` have added log-likelihood as `loglik`.
-
-## Upgrading to SparkR 2.3.0
-
- - The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE))`. It has been corrected.
- - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`.
- - A warning can be raised if versions of SparkR package and the Spark JVM do not match.
-
-## Upgrading to SparkR 2.3.1 and above
-
- - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-based. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1.
-
-## Upgrading to SparkR 2.4.0
-
- - Previously, we don't check the validity of the size of the last layer in `spark.mlp`. For example, if the training data only has two labels, a `layers` param like `c(1, 3)` doesn't cause an error previously, now it does.
-
-## Upgrading to SparkR 3.0.0
+The migration guide is now archived [on this page](sparkr-migration-guide.html).
- - The deprecated methods `sparkR.init`, `sparkRSQL.init`, `sparkRHive.init` have been removed. Use `sparkR.session` instead.
- - The deprecated methods `parquetFile`, `saveAsParquetFile`, `jsonFile`, `registerTempTable`, `createExternalTable`, and `dropTempTable` have been removed. Use `read.parquet`, `write.parquet`, `read.json`, `createOrReplaceTempView`, `createTable`, `dropTempView`, `union` instead.
diff --git a/docs/sql-keywords.md b/docs/sql-keywords.md
index 08be6b62a88e7..7a0e3efee8ffa 100644
--- a/docs/sql-keywords.md
+++ b/docs/sql-keywords.md
@@ -19,15 +19,15 @@ license: |
limitations under the License.
---
-When `spark.sql.parser.ansi.enabled` is true, Spark SQL has two kinds of keywords:
+When `spark.sql.ansi.enabled` is true, Spark SQL has two kinds of keywords:
* Reserved keywords: Keywords that are reserved and can't be used as identifiers for table, view, column, function, alias, etc.
* Non-reserved keywords: Keywords that have a special meaning only in particular contexts and can be used as identifiers in other contexts. For example, `SELECT 1 WEEK` is an interval literal, but WEEK can be used as identifiers in other places.
-When `spark.sql.parser.ansi.enabled` is false, Spark SQL has two kinds of keywords:
-* Non-reserved keywords: Same definition as the one when `spark.sql.parser.ansi.enabled=true`.
+When `spark.sql.ansi.enabled` is false, Spark SQL has two kinds of keywords:
+* Non-reserved keywords: Same definition as the one when `spark.sql.ansi.enabled=true`.
* Strict-non-reserved keywords: A strict version of non-reserved keywords, which can not be used as table alias.
-By default `spark.sql.parser.ansi.enabled` is false.
+By default `spark.sql.ansi.enabled` is false.
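+
+For example, the flag can be toggled at runtime (a minimal sketch, assuming an active `spark` session):
+
+```python
+# Enable the ANSI keyword behavior described above; reserved keywords can then
+# no longer be used as identifiers.
+spark.conf.set("spark.sql.ansi.enabled", "true")
+
+# Switch back to the default behavior.
+spark.conf.set("spark.sql.ansi.enabled", "false")
+```
+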
Below is a list of all the keywords in Spark SQL.
@@ -179,6 +179,8 @@ Below is a list of all the keywords in Spark SQL.
MONTH reserved non-reserved reserved
MONTHS non-reserved non-reserved non-reserved
MSCK non-reserved non-reserved non-reserved
+ NAMESPACE non-reserved non-reserved non-reserved
+ NAMESPACES non-reserved non-reserved non-reserved
NATURAL reserved strict-non-reserved reserved
NO non-reserved non-reserved reserved
NOT reserved non-reserved reserved
@@ -279,6 +281,7 @@ Below is a list of all the keywords in Spark SQL.
UNKNOWN reserved non-reserved reserved
UNLOCK non-reserved non-reserved non-reserved
UNSET non-reserved non-reserved non-reserved
+ UPDATE non-reserved non-reserved reserved
USE non-reserved non-reserved non-reserved
USER reserved non-reserved reserved
USING reserved strict-non-reserved reserved
diff --git a/docs/sql-migration-guide-hive-compatibility.md b/docs/sql-migration-guide-hive-compatibility.md
deleted file mode 100644
index d4b4fdf19d926..0000000000000
--- a/docs/sql-migration-guide-hive-compatibility.md
+++ /dev/null
@@ -1,167 +0,0 @@
----
-layout: global
-title: Compatibility with Apache Hive
-displayTitle: Compatibility with Apache Hive
-license: |
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
----
-
-* Table of contents
-{:toc}
-
-Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs.
-Currently, Hive SerDes and UDFs are based on Hive 1.2.1,
-and Spark SQL can be connected to different versions of Hive Metastore
-(from 0.12.0 to 2.3.6 and 3.0.0 to 3.1.2. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)).
-
-#### Deploying in Existing Hive Warehouses
-
-The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive
-installations. You do not need to modify your existing Hive Metastore or change the data placement
-or partitioning of your tables.
-
-### Supported Hive Features
-
-Spark SQL supports the vast majority of Hive features, such as:
-
-* Hive query statements, including:
- * `SELECT`
- * `GROUP BY`
- * `ORDER BY`
- * `CLUSTER BY`
- * `SORT BY`
-* All Hive operators, including:
- * Relational operators (`=`, `⇔`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc)
- * Arithmetic operators (`+`, `-`, `*`, `/`, `%`, etc)
- * Logical operators (`AND`, `&&`, `OR`, `||`, etc)
- * Complex type constructors
- * Mathematical functions (`sign`, `ln`, `cos`, etc)
- * String functions (`instr`, `length`, `printf`, etc)
-* User defined functions (UDF)
-* User defined aggregation functions (UDAF)
-* User defined serialization formats (SerDes)
-* Window functions
-* Joins
- * `JOIN`
- * `{LEFT|RIGHT|FULL} OUTER JOIN`
- * `LEFT SEMI JOIN`
- * `CROSS JOIN`
-* Unions
-* Sub-queries
- * `SELECT col FROM ( SELECT a + b AS col from t1) t2`
-* Sampling
-* Explain
-* Partitioned tables including dynamic partition insertion
-* View
- * If column aliases are not specified in view definition queries, both Spark and Hive will
- generate alias names, but in different ways. In order for Spark to be able to read views created
- by Hive, users should explicitly specify column aliases in view definition queries. As an
- example, Spark cannot read `v1` created as below by Hive.
-
- ```
- CREATE VIEW v1 AS SELECT * FROM (SELECT c + 1 FROM (SELECT 1 c) t1) t2;
- ```
-
- Instead, you should create `v1` as below with column aliases explicitly specified.
-
- ```
- CREATE VIEW v1 AS SELECT * FROM (SELECT c + 1 AS inc_c FROM (SELECT 1 c) t1) t2;
- ```
-
-* All Hive DDL Functions, including:
- * `CREATE TABLE`
- * `CREATE TABLE AS SELECT`
- * `ALTER TABLE`
-* Most Hive Data types, including:
- * `TINYINT`
- * `SMALLINT`
- * `INT`
- * `BIGINT`
- * `BOOLEAN`
- * `FLOAT`
- * `DOUBLE`
- * `STRING`
- * `BINARY`
- * `TIMESTAMP`
- * `DATE`
- * `ARRAY<>`
- * `MAP<>`
- * `STRUCT<>`
-
-### Unsupported Hive Functionality
-
-Below is a list of Hive features that we don't support yet. Most of these features are rarely used
-in Hive deployments.
-
-**Major Hive Features**
-
-* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL
- doesn't support buckets yet.
-
-
-**Esoteric Hive Features**
-
-* `UNION` type
-* Unique join
-* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at
- the moment and only supports populating the sizeInBytes field of the hive metastore.
-
-**Hive Input/Output Formats**
-
-* File format for CLI: For results showing back to the CLI, Spark SQL only supports TextOutputFormat.
-* Hadoop archive
-
-**Hive Optimizations**
-
-A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are
-less important due to Spark SQL's in-memory computational model. Others are slotted for future
-releases of Spark SQL.
-
-* Block-level bitmap indexes and virtual columns (used to build indexes)
-* Automatically determine the number of reducers for joins and groupbys: Currently, in Spark SQL, you
- need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`".
-* Meta-data only query: For queries that can be answered by using only metadata, Spark SQL still
- launches tasks to compute the result.
-* Skew data flag: Spark SQL does not follow the skew data flags in Hive.
-* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint.
-* Merge multiple small files for query results: if the result output contains multiple small files,
- Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS
- metadata. Spark SQL does not support that.
-
-**Hive UDF/UDTF/UDAF**
-
-Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are the unsupported APIs:
-
-* `getRequiredJars` and `getRequiredFiles` (`UDF` and `GenericUDF`) are functions to automatically
- include additional resources required by this UDF.
-* `initialize(StructObjectInspector)` in `GenericUDTF` is not supported yet. Spark SQL currently uses
- a deprecated interface `initialize(ObjectInspector[])` only.
-* `configure` (`GenericUDF`, `GenericUDTF`, and `GenericUDAFEvaluator`) is a function to initialize
- functions with `MapredContext`, which is inapplicable to Spark.
-* `close` (`GenericUDF` and `GenericUDAFEvaluator`) is a function to release associated resources.
- Spark SQL does not call this function when tasks finish.
-* `reset` (`GenericUDAFEvaluator`) is a function to re-initialize aggregation for reusing the same aggregation.
- Spark SQL currently does not support the reuse of aggregation.
-* `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating
- an aggregate over a fixed window.
-
-### Incompatible Hive UDF
-
-Below are the scenarios in which Hive and Spark generate different results:
-
-* `SQRT(n)` If n < 0, Hive returns null, Spark SQL returns NaN.
-* `ACOS(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN.
-* `ASIN(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN.
diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md
deleted file mode 100644
index cc3ef1e757756..0000000000000
--- a/docs/sql-migration-guide-upgrade.md
+++ /dev/null
@@ -1,829 +0,0 @@
----
-layout: global
-title: Spark SQL Upgrading Guide
-displayTitle: Spark SQL Upgrading Guide
-license: |
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
----
-
-* Table of contents
-{:toc}
-
-## Upgrading From Spark SQL 2.4 to 3.0
- - Since Spark 3.0, configuration `spark.sql.crossJoin.enabled` become internal configuration, and is true by default, so by default spark won't raise exception on sql with implicit cross join.
-
- - Since Spark 3.0, we reversed argument order of the trim function from `TRIM(trimStr, str)` to `TRIM(str, trimStr)` to be compatible with other databases.
-
- - Since Spark 3.0, PySpark requires a Pandas version of 0.23.2 or higher to use Pandas related functionality, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc.
-
- - Since Spark 3.0, PySpark requires a PyArrow version of 0.12.1 or higher to use PyArrow related functionality, such as `pandas_udf`, `toPandas` and `createDataFrame` with "spark.sql.execution.arrow.enabled=true", etc.
-
- - In Spark version 2.4 and earlier, SQL queries such as `FROM ` or `FROM UNION ALL FROM ` are supported by accident. In hive-style `FROM SELECT `, the `SELECT` clause is not negligible. Neither Hive nor Presto support this syntax. Therefore we will treat these queries as invalid since Spark 3.0.
-
- - Since Spark 3.0, the Dataset and DataFrame API `unionAll` is not deprecated any more. It is an alias for `union`.
-
- - In PySpark, when creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 3.0, the builder comes to not update the configurations. This is the same behavior as Java/Scala API in 2.3 and above. If you want to update them, you need to update them prior to creating a `SparkSession`.
-
- - In Spark version 2.4 and earlier, the parser of JSON data source treats empty strings as null for some data types such as `IntegerType`. For `FloatType` and `DoubleType`, it fails on empty strings and throws exceptions. Since Spark 3.0, we disallow empty strings and will throw exceptions for data types except for `StringType` and `BinaryType`.
-
- - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`.
-
- - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set.
-
- - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful.
-
- - In Spark version 2.4 and earlier, `Dataset.groupByKey` results in a grouped dataset with the key attribute wrongly named as "value" if the key is of non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute "key". The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`.
-
- - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but -0.0 and 0.0 are considered as different values when used in aggregate grouping keys, window partition keys and join keys. Since Spark 3.0, this bug is fixed. For example, `Seq(-0.0, 0.0).toDF("d").groupBy("d").count()` returns `[(0.0, 2)]` in Spark 3.0, and `[(0.0, 1), (-0.0, 1)]` in Spark 2.4 and earlier.
-
- - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with the last-wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet); the behavior will be undefined.
-
- - In Spark version 2.4 and earlier, partition column value is converted as null if it can't be casted to corresponding user provided schema. Since 3.0, partition column value is validated with user provided schema. An exception is thrown if the validation fails. You can disable such validation by setting `spark.sql.sources.validatePartitionColumns` to `false`.
-
- - In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.setCommandRejectsSparkCoreConfs` to `false`.
-
- - In Spark version 2.4 and earlier, CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, the returned row can contain non-`null` fields if some of CSV column values were parsed and converted to desired types successfully.
-
- - In Spark version 2.4 and earlier, JSON datasource and JSON functions like `from_json` convert a bad JSON record to a row with all `null`s in the PERMISSIVE mode when specified schema is `StructType`. Since Spark 3.0, the returned row can contain non-`null` fields if some of JSON column values were parsed and converted to desired types successfully.
-
- - Refreshing a cached table would trigger a table uncache operation and then a table cache (lazily) operation. In Spark version 2.4 and earlier, the cache name and storage level are not preserved before the uncache operation. Therefore, the cache name and storage level could be changed unexpectedly. Since Spark 3.0, cache name and storage level will be first preserved for cache recreation. It helps to maintain a consistent cache behavior upon table refreshing.
-
- - Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Set JSON option `inferTimestamp` to `false` to disable such type inferring.
-
- - In PySpark, when Arrow optimization is enabled, if Arrow version is higher than 0.11.0, Arrow can perform safe type conversion when converting Pandas.Series to Arrow array during serialization. Arrow will raise errors when detecting unsafe type conversion like overflow. Setting `spark.sql.execution.pandas.arrowSafeTypeConversion` to true can enable it. The default setting is false. PySpark's behavior for Arrow versions is illustrated in the table below:
-
-   | PyArrow version | Integer Overflow | Floating Point Truncation |
-   | --------------- | ---------------- | ------------------------- |
-   | version < 0.11.0 | Raise error | Silently allows |
-   | version > 0.11.0, arrowSafeTypeConversion=false | Silent overflow | Silently allows |
-   | version > 0.11.0, arrowSafeTypeConversion=true | Raise error | Raise error |
-
- - In Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(Any, DataType)` gets a Scala closure with a primitive-type argument, the returned UDF returns null if the input value is null. Since Spark 3.0, the UDF returns the default value of the Java type if the input value is null. For example, given `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` returns null in Spark 2.4 and earlier if column `x` is null, and 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
-
- - Since Spark 3.0, the Proleptic Gregorian calendar is used in parsing, formatting, and converting dates and timestamps as well as in extracting sub-components like years, days, etc. Spark 3.0 uses Java 8 API classes from the java.time packages that are based on ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html). In Spark version 2.4 and earlier, those operations are performed using the hybrid calendar (Julian + Gregorian, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html). The changes impact the results for dates before October 15, 1582 (Gregorian) and affect the following Spark 3.0 APIs:
-
- - CSV/JSON datasources use java.time API for parsing and generating CSV/JSON content. In Spark version 2.4 and earlier, java.text.SimpleDateFormat is used for the same purpose with fallbacks to the parsing mechanisms of Spark 2.0 and 1.x. For example, `2018-12-08 10:39:21.123` with the pattern `yyyy-MM-dd'T'HH:mm:ss.SSS` cannot be parsed since Spark 3.0 because the timestamp does not match to the pattern but it can be parsed by earlier Spark versions due to a fallback to `Timestamp.valueOf`. To parse the same timestamp since Spark 3.0, the pattern should be `yyyy-MM-dd HH:mm:ss.SSS`.
-
- - The `unix_timestamp`, `date_format`, `to_unix_timestamp`, `from_unixtime`, `to_date`, `to_timestamp` functions. The new implementation supports pattern formats as described at https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html and performs strict checking of its input. For example, the `2015-07-22 10:00:00` timestamp cannot be parsed if the pattern is `yyyy-MM-dd` because the parser does not consume the whole input. Another example is that the `31/01/2015 00:00` input cannot be parsed by the `dd/MM/yyyy hh:mm` pattern because `hh` assumes hours in the range `1-12`.
-
- - The `weekofyear`, `weekday`, `dayofweek`, `date_trunc`, `from_utc_timestamp`, `to_utc_timestamp`, and `unix_timestamp` functions use java.time API for calculation week number of year, day number of week as well for conversion from/to TimestampType values in UTC time zone.
-
- - the JDBC options `lowerBound` and `upperBound` are converted to TimestampType/DateType values in the same way as casting strings to TimestampType/DateType values. The conversion is based on Proleptic Gregorian calendar, and time zone defined by the SQL config `spark.sql.session.timeZone`. In Spark version 2.4 and earlier, the conversion is based on the hybrid calendar (Julian + Gregorian) and on default system time zone.
-
- - Formatting of `TIMESTAMP` and `DATE` literals.
-
- - In Spark version 2.4 and earlier, invalid time zone ids are silently ignored and replaced by GMT time zone, for example, in the from_utc_timestamp function. Since Spark 3.0, such time zone ids are rejected, and Spark throws `java.time.DateTimeException`.
-
- - In Spark version 2.4 and earlier, the `current_timestamp` function returns a timestamp with millisecond resolution only. Since Spark 3.0, the function can return the result with microsecond resolution if the underlying clock available on the system offers such resolution.
-
- - In Spark version 2.4 and earlier, when reading a Hive Serde table with Spark native data sources(parquet/orc), Spark will infer the actual file schema and update the table schema in metastore. Since Spark 3.0, Spark doesn't infer the schema anymore. This should not cause any problems to end users, but if it does, please set `spark.sql.hive.caseSensitiveInferenceMode` to `INFER_AND_SAVE`.
-
- - Since Spark 3.0, `TIMESTAMP` literals are converted to strings using the SQL config `spark.sql.session.timeZone`. In Spark version 2.4 and earlier, the conversion uses the default time zone of the Java virtual machine.
-
- - In Spark version 2.4, when a spark session is created via `cloneSession()`, the newly created spark session inherits its configuration from its parent `SparkContext` even though the same configuration may exist with a different value in its parent spark session. Since Spark 3.0, the configurations of a parent `SparkSession` have a higher precedence over the parent `SparkContext`. The old behavior can be restored by setting `spark.sql.legacy.sessionInitWithConfigDefaults` to `true`.
-
- - Since Spark 3.0, the parquet logical type `TIMESTAMP_MICROS` is used by default while saving `TIMESTAMP` columns. In Spark version 2.4 and earlier, `TIMESTAMP` columns are saved as `INT96` in parquet files. Setting `spark.sql.parquet.outputTimestampType` to `INT96` restores the previous behavior.
-
- - Since Spark 3.0, if `hive.default.fileformat` is not found in `Spark SQL configuration` then it will fallback to hive-site.xml present in the `Hadoop configuration` of `SparkContext`.
-
- - Since Spark 3.0, Spark will cast `String` to `Date/TimeStamp` in binary comparisons with dates/timestamps. The previous behaviour of casting `Date/Timestamp` to `String` can be restored by setting `spark.sql.legacy.typeCoercion.datetimeToString` to `true`.
-
- - Since Spark 3.0, when Avro files are written with user provided schema, the fields will be matched by field names between catalyst schema and avro schema instead of positions.
-
- - Since Spark 3.0, when Avro files are written with a user provided non-nullable schema, even if the catalyst schema is nullable, Spark is still able to write the files. However, Spark will throw a runtime NPE if any of the records contains null.
-
- - Since Spark 3.0, a new protocol is used for fetching shuffle blocks, so users of the external shuffle service need to upgrade the server correspondingly. Otherwise, they will get the error message `UnsupportedOperationException: Unexpected message: FetchShuffleBlocks`. If it is hard to upgrade the shuffle service right now, you can still use the old protocol by setting `spark.shuffle.useOldFetchProtocol` to `true`.
-
- - Since Spark 3.0, a higher-order function `exists` follows the three-valued boolean logic, i.e., if the `predicate` returns any `null`s and no `true` is obtained, then `exists` will return `null` instead of `false`. For example, `exists(array(1, null, 3), x -> x % 2 == 0)` will be `null`. The previous behaviour can be restored by setting `spark.sql.legacy.arrayExistsFollowsThreeValuedLogic` to `false`.
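-
-   For example, a minimal sketch of the new behavior (assuming a `spark` session, e.g. in `spark-shell`):
-
-   {% highlight scala %}
-   // The predicate is null for the null element and never true, so the result is null in Spark 3.0
-   spark.sql("SELECT exists(array(1, null, 3), x -> x % 2 == 0)").show()
-
-   // Make such cases return false instead of null, as before Spark 3.0
-   spark.conf.set("spark.sql.legacy.arrayExistsFollowsThreeValuedLogic", "false")
-   {% endhighlight %}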
-
- - Since Spark 3.0, if files or subdirectories disappear during recursive directory listing (i.e. they appear in an intermediate listing but then cannot be read or listed during later phases of the recursive directory listing, due to either concurrent file deletions or object store consistency issues) then the listing will fail with an exception unless `spark.sql.files.ignoreMissingFiles` is `true` (default `false`). In previous versions, these missing files or subdirectories would be ignored. Note that this change of behavior only applies during initial table file listing (or during `REFRESH TABLE`), not during query execution: the net change is that `spark.sql.files.ignoreMissingFiles` is now obeyed during table file listing / query planning, not only at query execution time.
-
- - Since Spark 3.0, `createDataFrame(..., verifySchema=True)` validates `LongType` as well in PySpark. Previously, `LongType` was not verified and resulted in `None` in case the value overflows. To restore this behavior, `verifySchema` can be set to `False` to disable the validation.
-
- - Since Spark 3.0, substitution order of nested WITH clauses is changed and an inner CTE definition takes precedence over an outer. In version 2.4 and earlier, `WITH t AS (SELECT 1), t2 AS (WITH t AS (SELECT 2) SELECT * FROM t) SELECT * FROM t2` returns `1` while in version 3.0 it returns `2`. The previous behaviour can be restored by setting `spark.sql.legacy.ctePrecedence.enabled` to `true`.
-
- - Since Spark 3.0, the `add_months` function does not adjust the resulting date to the last day of the month if the original date is the last day of a month. For example, `select add_months(DATE'2019-02-28', 1)` results in `2019-03-28`. In Spark version 2.4 and earlier, the resulting date is adjusted when the original date is the last day of a month. For example, adding a month to `2019-02-28` results in `2019-03-31`.
-
- - Since Spark 3.0, a 0-argument Java UDF is executed on the executor side identically to other UDFs. In Spark version 2.4 and earlier, a 0-argument Java UDF alone was executed on the driver side, and the result was propagated to executors, which might be more performant in some cases but caused inconsistency with a correctness issue in some cases.
-
- - The result of `java.lang.Math`'s `log`, `log1p`, `exp`, `expm1`, and `pow` may vary across platforms. In Spark 3.0, the result of the equivalent SQL functions (including related SQL functions like `LOG10`) return values consistent with `java.lang.StrictMath`. In virtually all cases this makes no difference in the return value, and the difference is very small, but may not exactly match `java.lang.Math` on x86 platforms in cases like, for example, `log(3.0)`, whose value varies between `Math.log()` and `StrictMath.log()`.
-
- - Since Spark 3.0, a Dataset query fails if it contains an ambiguous column reference caused by a self join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result, which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behavior before Spark 3.0, you can set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`.
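-
-   A sketch of the failing pattern (illustrative only; assumes `spark.implicits._` is imported):
-
-   {% highlight scala %}
-   val df1 = Seq(1, 2, 3).toDF("a")
-   val df2 = df1.filter($"a" > 1)
-
-   // df1("a") and df2("a") resolve to the same attribute, so the join condition is
-   // ambiguous and the query fails analysis in Spark 3.0
-   df1.join(df2, df1("a") > df2("a"))
-
-   // Restore the pre-3.0 behavior
-   spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", "false")
-   {% endhighlight %}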
-
- - Since Spark 3.0, `Cast` function processes string literals such as 'Infinity', '+Infinity', '-Infinity', 'NaN', 'Inf', '+Inf', '-Inf' in case insensitive manner when casting the literals to `Double` or `Float` type to ensure greater compatibility with other database systems. This behaviour change is illustrated in the table below:
-
-   | Operation | Result prior to Spark 3.0 | Result starting Spark 3.0 |
-   | --------- | ------------------------- | ------------------------- |
-   | CAST('infinity' AS DOUBLE)<br>CAST('+infinity' AS DOUBLE)<br>CAST('inf' AS DOUBLE)<br>CAST('+inf' AS DOUBLE) | NULL | Double.PositiveInfinity |
-   | CAST('-infinity' AS DOUBLE)<br>CAST('-inf' AS DOUBLE) | NULL | Double.NegativeInfinity |
-   | CAST('infinity' AS FLOAT)<br>CAST('+infinity' AS FLOAT)<br>CAST('inf' AS FLOAT)<br>CAST('+inf' AS FLOAT) | NULL | Float.PositiveInfinity |
-   | CAST('-infinity' AS FLOAT)<br>CAST('-inf' AS FLOAT) | NULL | Float.NegativeInfinity |
-   | CAST('nan' AS DOUBLE) | NULL | Double.NaN |
-   | CAST('nan' AS FLOAT) | NULL | Float.NaN |
-
-## Upgrading from Spark SQL 2.4 to 2.4.1
-
- - The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was
- inconsistently interpreted as both seconds and milliseconds in Spark 2.4.0 in different parts of the code.
- Unitless values are now consistently interpreted as milliseconds. Applications that set values like "30"
- need to specify a value with units like "30s" now, to avoid being interpreted as milliseconds; otherwise,
- the extremely short interval that results will likely cause applications to fail.
-
- - When turning a Dataset into another Dataset, Spark will up-cast the fields in the original Dataset to the types of the corresponding fields in the target Dataset. In version 2.4 and earlier, this up-cast is not very strict, e.g. `Seq("str").toDS.as[Int]` fails, but `Seq("str").toDS.as[Boolean]` works and throws an NPE during execution. In Spark 3.0, the up-cast is stricter and turning a String into something else is not allowed, i.e. `Seq("str").toDS.as[Boolean]` will fail during analysis.
-
-## Upgrading From Spark SQL 2.3 to 2.4
-
- - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below.
-
-   | Query | Spark 2.3 or Prior | Spark 2.4 | Remarks |
-   | ----- | ------------------ | --------- | ------- |
-   | SELECT array_contains(array(1), 1.34D); | true | false | In Spark 2.4, the left and right parameters are promoted to array type of double type and double type respectively. |
-   | SELECT array_contains(array(1), '1'); | true | AnalysisException is thrown. | An explicit cast can be used in the arguments to avoid the exception. In Spark 2.4, AnalysisException is thrown since integer type cannot be promoted to string type in a loss-less manner. |
-   | SELECT array_contains(array(1), 'anystring'); | null | AnalysisException is thrown. | An explicit cast can be used in the arguments to avoid the exception. In Spark 2.4, AnalysisException is thrown since integer type cannot be promoted to string type in a loss-less manner. |
-
- - Since Spark 2.4, when there is a struct field in front of the IN operator before a subquery, the inner query must contain a struct field as well. In previous versions, instead, the fields of the struct were compared to the output of the inner query. E.g., if `a` is a `struct(a string, b int)`, in Spark 2.4 `a in (select (1 as a, 'a' as b) from range(1))` is a valid query, while `a in (select 1, 'a' from range(1))` is not. In previous versions it was the opposite.
-
- - In versions 2.2.1+ and 2.3, if `spark.sql.caseSensitive` is set to true, then the `CURRENT_DATE` and `CURRENT_TIMESTAMP` functions incorrectly became case-sensitive and would resolve to columns (unless typed in lower case). In Spark 2.4 this has been fixed and the functions are no longer case-sensitive.
-
- - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations is preserved under a newly added configuration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, Spark will evaluate the set operators from left to right as they appear in the query, given no explicit ordering is enforced by usage of parentheses.
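-
-   For example, a sketch of how the precedence change can alter a result (assuming a `spark` session):
-
-   {% highlight scala %}
-   // Since 2.4 this is evaluated as SELECT 1 UNION (SELECT 2 INTERSECT SELECT 2),
-   // i.e. INTERSECT binds tighter than UNION, so both rows 1 and 2 are returned
-   spark.sql("SELECT 1 UNION SELECT 2 INTERSECT SELECT 2").show()
-
-   // Evaluate set operators left to right, as before 2.4
-   spark.conf.set("spark.sql.legacy.setopsPrecedence.enabled", "true")
-   {% endhighlight %}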
-
- - Since Spark 2.4, Spark will display table description column Last Access value as UNKNOWN when the value was Jan 01 1970.
-
- - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. ORC files created by native ORC writer cannot be read by some old Apache Hive releases. Use `spark.sql.orc.impl=hive` to create the files shared with Hive 2.1.1 and older.
-
- - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`.
-
- - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe.
-
- - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, a column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``.
-
- - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema.
-
- - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promoting both sides to TIMESTAMP. Setting `spark.sql.legacy.compareDateTimestampInTimestamp` to `false` restores the previous behavior. This option will be removed in Spark 3.0.
-
- - Since Spark 2.4, creating a managed table with a nonempty location is not allowed. An exception is thrown when attempting to create a managed table with a nonempty location. Setting `spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation` to `true` restores the previous behavior. This option will be removed in Spark 3.0.
-
- - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location.
-
- - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, regardless of the order of the input arguments. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception.
-
- - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time.
-
- - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files.
-
- - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. Setting `spark.sql.hive.convertMetastoreOrc` to `false` restores the previous behavior.
-
- - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. The CSV parser drops such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, a CSV row is considered as malformed only when it contains malformed column values requested from the CSV datasource; other values can be ignored. As an example, a CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with the single column value 1234, but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`.
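-
-   A sketch of the example above (the path and file contents are hypothetical):
-
-   {% highlight scala %}
-   // /tmp/people.csv is assumed to contain the header "id,name" and a single row "1234"
-   val df = spark.read
-     .option("header", "true")
-     .option("mode", "DROPMALFORMED")
-     .schema("id INT, name STRING")
-     .csv("/tmp/people.csv")
-
-   // Since 2.4 the row (1234) is kept because only the requested "id" column is parsed;
-   // in Spark 2.3 and earlier the whole row is dropped as malformed
-   df.select("id").show()
-
-   // Restore the previous behavior
-   spark.conf.set("spark.sql.csv.parser.columnPruning.enabled", "false")
-   {% endhighlight %}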
-
- - Since Spark 2.4, file listing for computing statistics is done in parallel by default. This can be disabled by setting `spark.sql.statistics.parallelFileListingInStatsComputation.enabled` to `false`.
-
- - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation.
-
- - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string.
-
- - Since Spark 2.4, The LOAD DATA command supports wildcard `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special Characters like `space` also now work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`.
-
- - In Spark version 2.3 and earlier, HAVING without GROUP BY is treated as WHERE. This means, `SELECT 1 FROM range(10) HAVING true` is executed as `SELECT 1 FROM range(10) WHERE true` and returns 10 rows. This violates SQL standard, and has been fixed in Spark 2.4. Since Spark 2.4, HAVING without GROUP BY is treated as a global aggregate, which means `SELECT 1 FROM range(10) HAVING true` will return only one row. To restore the previous behavior, set `spark.sql.legacy.parser.havingWithoutGroupByAsWhere` to `true`.
-
- - In version 2.3 and earlier, when reading from a Parquet data source table, Spark always returns null for any column whose column names in Hive metastore schema and Parquet schema are in different letter cases, no matter whether `spark.sql.caseSensitive` is set to `true` or `false`. Since 2.4, when `spark.sql.caseSensitive` is set to `false`, Spark does case insensitive column name resolution between Hive metastore schema and Parquet schema, so even column names are in different letter cases, Spark returns corresponding column values. An exception is thrown if there is ambiguity, i.e. more than one Parquet column is matched. This change also applies to Parquet Hive tables when `spark.sql.hive.convertMetastoreParquet` is set to `true`.
-
-## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above
-
- - As of version 2.3.1 Arrow functionality, including `pandas_udf` and `toPandas()`/`createDataFrame()` with `spark.sql.execution.arrow.enabled` set to `True`, has been marked as experimental. These are still evolving and not currently recommended for use in production.
-
-## Upgrading From Spark SQL 2.2 to 2.3
-
- - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
-
- - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles.
-
- - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown.
-
- - Partition column inference previously found incorrect common type for different inferred types, for example, previously it ended up with double type as the common type for double type and date type. Now it finds the correct common type for such conflicts. The conflict resolution follows the table below:
-
-   | InputA \ InputB | NullType | IntegerType | LongType | DecimalType(38,0)* | DoubleType | DateType | TimestampType | StringType |
-   | --------------- | -------- | ----------- | -------- | ------------------ | ---------- | -------- | ------------- | ---------- |
-   | NullType | NullType | IntegerType | LongType | DecimalType(38,0) | DoubleType | DateType | TimestampType | StringType |
-   | IntegerType | IntegerType | IntegerType | LongType | DecimalType(38,0) | DoubleType | StringType | StringType | StringType |
-   | LongType | LongType | LongType | LongType | DecimalType(38,0) | StringType | StringType | StringType | StringType |
-   | DecimalType(38,0)* | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | StringType | StringType | StringType | StringType |
-   | DoubleType | DoubleType | DoubleType | StringType | StringType | DoubleType | StringType | StringType | StringType |
-   | DateType | DateType | StringType | StringType | StringType | StringType | DateType | TimestampType | StringType |
-   | TimestampType | TimestampType | StringType | StringType | StringType | StringType | TimestampType | TimestampType | StringType |
-   | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType |
-
- Note that, for DecimalType(38,0)* , the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type.
-
- - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc.
-
- - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details.
-
- - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame.
-
- - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcast the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](sql-performance-tuning.html#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489).
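-
-   For example, a sketch with an explicit broadcast hint (`df1` and `df2` are hypothetical DataFrames sharing a `key` column):
-
-   {% highlight scala %}
-   import org.apache.spark.sql.functions.broadcast
-
-   // Since 2.3, when both sides could be broadcast, Spark prefers the explicitly hinted side (df2 here)
-   val joined = df1.join(broadcast(df2), Seq("key"))
-   {% endhighlight %}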
-
- - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Before Spark 2.3, it always returned a string regardless of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.
-
- - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Before Spark 2.3, it always returned a string regardless of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`.
-
- - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant with the SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes:
-
- - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, i.e. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive modulus (`pmod`).
-
- - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them.
-
- - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, i.e. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible.
-
- - In PySpark, `df.replace` does not allow omitting `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone.
-
- - The semantics of un-aliased subqueries have not been well defined and had confusing behaviors. Since Spark 2.3, we invalidate such confusing cases, for example: `SELECT v.i from (SELECT i FROM v)`; Spark will throw an analysis exception in this case because users should not be able to use the qualifier inside a subquery. See [SPARK-20690](https://issues.apache.org/jira/browse/SPARK-20690) and [SPARK-21335](https://issues.apache.org/jira/browse/SPARK-21335) for more details.
-
- - When creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 2.3, the builder no longer updates the configurations. If you want to update them, you need to do so prior to creating a `SparkSession`.
-
-## Upgrading From Spark SQL 2.1 to 2.2
-
- - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
-
- - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty).
-
- - Since Spark 2.2, view definitions are stored in a different way from prior versions. This may cause Spark unable to read views created by prior versions. In such cases, you need to recreate the views using `ALTER VIEW AS` or `CREATE OR REPLACE VIEW AS` with newer Spark versions.
-
-## Upgrading From Spark SQL 2.0 to 2.1
-
- - Datasource tables now store partition metadata in the Hive metastore. This means that Hive DDLs such as `ALTER TABLE PARTITION ... SET LOCATION` are now available for tables created with the Datasource API.
-
- - Legacy datasource tables can be migrated to this format via the `MSCK REPAIR TABLE` command. Migrating legacy tables is recommended to take advantage of Hive DDL support and improved planning performance.
-
- - To determine if a table has been migrated, look for the `PartitionProvider: Catalog` attribute when issuing `DESCRIBE FORMATTED` on the table.
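-
-   For example, a sketch of migrating and checking a legacy table (the table name is hypothetical):
-
-   {% highlight scala %}
-   // Import partition metadata of a legacy Datasource table into the Hive metastore
-   spark.sql("MSCK REPAIR TABLE my_datasource_table")
-
-   // A migrated table reports "PartitionProvider: Catalog"
-   spark.sql("DESCRIBE FORMATTED my_datasource_table").show(truncate = false)
-   {% endhighlight %}
-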
- - Changes to `INSERT OVERWRITE TABLE ... PARTITION ...` behavior for Datasource tables.
-
- - In prior Spark versions `INSERT OVERWRITE` overwrote the entire Datasource table, even when given a partition specification. Now only partitions matching the specification are overwritten.
-
- - Note that this still differs from the behavior of Hive tables, which is to overwrite only partitions overlapping with newly inserted data.
-
-## Upgrading From Spark SQL 1.6 to 2.0
-
- - `SparkSession` is now the new entry point of Spark that replaces the old `SQLContext` and
-   `HiveContext`. Note that the old SQLContext and HiveContext are kept for backward compatibility. A new `catalog` interface is accessible from `SparkSession` - existing APIs for accessing databases and tables, such as `listTables`, `createExternalTable`, `dropTempView` and `cacheTable`, are moved here.
-
- - Dataset API and DataFrame API are unified. In Scala, `DataFrame` becomes a type alias for
- `Dataset[Row]`, while Java API users must replace `DataFrame` with `Dataset`. Both the typed
- transformations (e.g., `map`, `filter`, and `groupByKey`) and untyped transformations (e.g.,
- `select` and `groupBy`) are available on the Dataset class. Since compile-time type-safety in
- Python and R is not a language feature, the concept of Dataset does not apply to these languages’
- APIs. Instead, `DataFrame` remains the primary programming abstraction, which is analogous to the
- single-node data frame notion in these languages.
-
- - Dataset and DataFrame API `unionAll` has been deprecated and replaced by `union`
-
- - Dataset and DataFrame API `explode` has been deprecated, alternatively, use `functions.explode()` with `select` or `flatMap`
-
- - Dataset and DataFrame API `registerTempTable` has been deprecated and replaced by `createOrReplaceTempView`
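-
-   A sketch of the replacements for the deprecated calls above (illustrative only; `df`, `df1`, `df2` and the column/view names are hypothetical, and `spark.implicits._` is assumed to be imported):
-
-   {% highlight scala %}
-   import org.apache.spark.sql.functions.explode
-
-   df1.union(df2)                            // instead of df1.unionAll(df2)
-   df.select($"id", explode($"values"))      // instead of df.explode(...)
-   df.createOrReplaceTempView("my_view")     // instead of df.registerTempTable("my_view")
-   {% endhighlight %}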
-
- - Changes to `CREATE TABLE ... LOCATION` behavior for Hive tables.
-
- - From Spark 2.0, `CREATE TABLE ... LOCATION` is equivalent to `CREATE EXTERNAL TABLE ... LOCATION`
- in order to prevent accidental dropping the existing data in the user-provided locations.
- That means, a Hive table created in Spark SQL with the user-specified location is always a Hive external table.
- Dropping external tables will not remove the data. Users are not allowed to specify the location for Hive managed tables.
- Note that this is different from the Hive behavior.
-
- - As a result, `DROP TABLE` statements on those tables will not remove the data.
-
- - `spark.sql.parquet.cacheMetadata` is no longer used.
- See [SPARK-13664](https://issues.apache.org/jira/browse/SPARK-13664) for details.
-
-## Upgrading From Spark SQL 1.5 to 1.6
-
- - From Spark 1.6, by default, the Thrift server runs in multi-session mode, which means each JDBC/ODBC
-   connection owns a copy of its own SQL configuration and temporary function registry. Cached
- tables are still shared though. If you prefer to run the Thrift server in the old single-session
- mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add
- this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`:
-
- {% highlight bash %}
- ./sbin/start-thriftserver.sh \
- --conf spark.sql.hive.thriftServer.singleSession=true \
- ...
- {% endhighlight %}
-
- - Since 1.6.1, the withColumn method in SparkR supports adding a new column to a DataFrame or replacing
-   existing columns of the same name.
-
- - From Spark 1.6, LongType casts to TimestampType expect seconds instead of microseconds. This
- change was made to match the behavior of Hive 1.2 for more consistent type casting to TimestampType
- from numeric types. See [SPARK-11724](https://issues.apache.org/jira/browse/SPARK-11724) for
- details.
-
-## Upgrading From Spark SQL 1.4 to 1.5
-
- - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with
- code generation for expression evaluation. These features can both be disabled by setting
- `spark.sql.tungsten.enabled` to `false`.
-
- - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting
- `spark.sql.parquet.mergeSchema` to `true`.
-
- - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or
- access nested values. For example `df['table.column.nestedField']`. However, this means that if
- your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``).
-
- - In-memory columnar storage partition pruning is on by default. It can be disabled by setting
- `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`.
-
- - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum
- precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now
- used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`.
-
- - Timestamps are now stored at a precision of 1us, rather than 1ns
-
- - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains
- unchanged.
-
- - The canonical names of SQL/DataFrame functions are now lower case (e.g., sum vs SUM).
-
- - JSON data source will not automatically load new files that are created by other applications
- (i.e. files that are not inserted to the dataset through Spark SQL).
- For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore),
- users can use `REFRESH TABLE` SQL command or `HiveContext`'s `refreshTable` method
- to include those new files to the table. For a DataFrame representing a JSON dataset, users need to recreate
- the DataFrame and the new DataFrame will include new files.
-
- - DataFrame.withColumn method in pySpark supports adding a new column or replacing existing columns of the same name.
-
-## Upgrading from Spark SQL 1.3 to 1.4
-
-#### DataFrame data reader/writer interface
-
-Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`)
-and writing data out (`DataFrame.write`),
-and deprecated the old APIs (e.g., `SQLContext.parquetFile`, `SQLContext.jsonFile`).
-
-See the API docs for `SQLContext.read` (Scala, Java, Python) and `DataFrame.write` (Scala, Java, Python)
-for more information.
-
-
-#### DataFrame.groupBy retains grouping columns
-
-Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the
-grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
-
-
-
-{% highlight scala %}
-
-// In 1.3.x, in order for the grouping column "department" to show up,
-// it must be included explicitly as part of the agg function call.
-df.groupBy("department").agg($"department", max("age"), sum("expense"))
-
-// In 1.4+, grouping column "department" is included automatically.
-df.groupBy("department").agg(max("age"), sum("expense"))
-
-// Revert to 1.3 behavior (not retaining grouping column) by:
-sqlContext.setConf("spark.sql.retainGroupColumns", "false")
-
-{% endhighlight %}
-
-
-
-{% highlight java %}
-
-// In 1.3.x, in order for the grouping column "department" to show up,
-// it must be included explicitly as part of the agg function call.
-df.groupBy("department").agg(col("department"), max("age"), sum("expense"));
-
-// In 1.4+, grouping column "department" is included automatically.
-df.groupBy("department").agg(max("age"), sum("expense"));
-
-// Revert to 1.3 behavior (not retaining grouping column) by:
-sqlContext.setConf("spark.sql.retainGroupColumns", "false");
-
-{% endhighlight %}
-
-{% highlight python %}
-
-import pyspark.sql.functions as func
-
-# In 1.3.x, in order for the grouping column "department" to show up,
-# it must be included explicitly as part of the agg function call.
-df.groupBy("department").agg(df["department"], func.max("age"), func.sum("expense"))
-
-# In 1.4+, grouping column "department" is included automatically.
-df.groupBy("department").agg(func.max("age"), func.sum("expense"))
-
-# Revert to 1.3.x behavior (not retaining grouping column) by:
-sqlContext.setConf("spark.sql.retainGroupColumns", "false")
-
-{% endhighlight %}
-
-
-#### Behavior change on DataFrame.withColumn
-
-Prior to 1.4, DataFrame.withColumn() supported adding a column only. The column was always added
-as a new column with its specified name in the resulting DataFrame, even if there was an existing
-column of the same name. Since 1.4, DataFrame.withColumn() supports adding a column with a name
-different from the names of all existing columns, or replacing an existing column of the same name.
-
-Note that this change is only for Scala API, not for PySpark and SparkR.
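-
-A minimal sketch of the 1.4+ behavior (assuming a DataFrame `df` that already has an `age` column):
-
-{% highlight scala %}
-// Since 1.4 this replaces the existing "age" column instead of always adding a new one
-val updated = df.withColumn("age", df("age") + 1)
-{% endhighlight %}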
-
-
-## Upgrading from Spark SQL 1.0-1.2 to 1.3
-
-In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the
-available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other
-releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked
-as unstable (i.e., DeveloperAPI or Experimental).
-
-#### Rename of SchemaRDD to DataFrame
-
-The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has
-been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD
-directly, but instead provide most of the functionality that RDDs provide through their own
-implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method.
-
-In Scala, there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for
-some use cases. It is still recommended that users update their code to use `DataFrame` instead.
-Java and Python users will need to update their code.
-
-#### Unification of the Java and Scala APIs
-
-Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`)
-that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users
-of either language should use `SQLContext` and `DataFrame`. In general these classes try to
-use types that are usable from both languages (i.e. `Array` instead of language-specific collections).
-In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading
-is used instead.
-
-Additionally, the Java specific types API has been removed. Users of both Scala and Java should
-use the classes present in `org.apache.spark.sql.types` to describe schema programmatically.
-
-
-#### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only)
-
-Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought
-all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit
-conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`.
-Users should now write `import sqlContext.implicits._`.
-
-Additionally, the implicit conversions now only augment RDDs that are composed of `Product`s (i.e.,
-case classes or tuples) with a method `toDF`, instead of applying automatically.
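-
-For example, a minimal sketch (assuming an existing `SparkContext` named `sc` and `SQLContext` named `sqlContext`):
-
-{% highlight scala %}
-import sqlContext.implicits._
-
-// The implicit conversions add toDF to RDDs of Products (case classes or tuples)
-val df = sc.parallelize(Seq((1, "a"), (2, "b"))).toDF("id", "name")
-{% endhighlight %}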
-
-When using functions inside the DSL (now replaced with the `DataFrame` API) users used to import
-`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used:
-`import org.apache.spark.sql.functions._`.
-
-#### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only)
-
-Spark 1.3 removes the type aliases that were present in the base sql package for `DataType`. Users
-should instead import the classes in `org.apache.spark.sql.types`.
-
-#### UDF Registration Moved to `sqlContext.udf` (Java & Scala)
-
-Functions that are used to register UDFs, either for use in the DataFrame DSL or SQL, have been
-moved into the udf object in `SQLContext`.
-
-
-
-{% highlight scala %}
-
-sqlContext.udf.register("strLen", (s: String) => s.length())
-
-{% endhighlight %}
-
-
-
-{% highlight java %}
-
-sqlContext.udf().register("strLen", (String s) -> s.length(), DataTypes.IntegerType);
-
-{% endhighlight %}
-
-Python UDF registration is unchanged.
-
-#### Python DataTypes No Longer Singletons
-
-When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of
-referencing a singleton.
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 4c23147106b65..b52a57acdd7bc 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -1,7 +1,7 @@
---
layout: global
-title: Migration Guide
-displayTitle: Migration Guide
+title: "Migration Guide: SQL, Datasets and DataFrame"
+displayTitle: "Migration Guide: SQL, Datasets and DataFrame"
license: |
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
@@ -19,20 +19,906 @@ license: |
limitations under the License.
---
-* [Spark SQL Upgrading Guide](sql-migration-guide-upgrade.html)
- * [Upgrading From Spark SQL 2.4 to 3.0](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-24-to-30)
- * [Upgrading From Spark SQL 2.3 to 2.4](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-23-to-24)
- * [Upgrading From Spark SQL 2.3.0 to 2.3.1 and above](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-230-to-231-and-above)
- * [Upgrading From Spark SQL 2.2 to 2.3](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-22-to-23)
- * [Upgrading From Spark SQL 2.1 to 2.2](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-21-to-22)
- * [Upgrading From Spark SQL 2.0 to 2.1](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-20-to-21)
- * [Upgrading From Spark SQL 1.6 to 2.0](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-16-to-20)
- * [Upgrading From Spark SQL 1.5 to 1.6](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-15-to-16)
- * [Upgrading From Spark SQL 1.4 to 1.5](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-14-to-15)
- * [Upgrading from Spark SQL 1.3 to 1.4](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-13-to-14)
- * [Upgrading from Spark SQL 1.0-1.2 to 1.3](sql-migration-guide-upgrade.html#upgrading-from-spark-sql-10-12-to-13)
-* [Compatibility with Apache Hive](sql-migration-guide-hive-compatibility.html)
- * [Deploying in Existing Hive Warehouses](sql-migration-guide-hive-compatibility.html#deploying-in-existing-hive-warehouses)
- * [Supported Hive Features](sql-migration-guide-hive-compatibility.html#supported-hive-features)
- * [Unsupported Hive Functionality](sql-migration-guide-hive-compatibility.html#unsupported-hive-functionality)
- * [Incompatible Hive UDF](sql-migration-guide-hive-compatibility.html#incompatible-hive-udf)
+* Table of contents
+{:toc}
+
+## Upgrading from Spark SQL 2.4 to 3.0
+
+ - In Spark 3.0, the deprecated methods `SQLContext.createExternalTable` and `SparkSession.createExternalTable` have been removed in favor of their replacement, `createTable`.
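+
+   For example, a sketch of the replacement call (the table name and path are hypothetical):
+
+   {% highlight scala %}
+   // Spark 2.x (deprecated): spark.catalog.createExternalTable("my_table", "/path/to/data", "parquet")
+   // Spark 3.0 and later:
+   spark.catalog.createTable("my_table", "/path/to/data", "parquet")
+   {% endhighlight %}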
+
+ - In Spark 3.0, the deprecated `HiveContext` class has been removed. Use `SparkSession.builder.enableHiveSupport()` instead.
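+
+   For example, a minimal sketch of creating a Hive-enabled session in Spark 3.0:
+
+   {% highlight scala %}
+   import org.apache.spark.sql.SparkSession
+
+   val spark = SparkSession.builder()
+     .appName("my-app")
+     .enableHiveSupport()   // replaces the removed HiveContext
+     .getOrCreate()
+   {% endhighlight %}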
+
+ - Since Spark 3.0, the configuration `spark.sql.crossJoin.enabled` has become an internal configuration and is `true` by default, so by default Spark won't raise an exception on SQL with an implicit cross join.
+
+ - Since Spark 3.0, we reversed argument order of the trim function from `TRIM(trimStr, str)` to `TRIM(str, trimStr)` to be compatible with other databases.
+
+ - In Spark version 2.4 and earlier, SQL queries such as `FROM <table>` or `FROM <table> UNION ALL FROM <table>` are supported by accident. In hive-style `FROM <table> SELECT <expr>`, the `SELECT` clause is not negligible. Neither Hive nor Presto support this syntax. Therefore these queries are treated as invalid since Spark 3.0.
+
+ - Since Spark 3.0, the Dataset and DataFrame API `unionAll` is not deprecated any more. It is an alias for `union`.
+
+ - In Spark version 2.4 and earlier, the parser of JSON data source treats empty strings as null for some data types such as `IntegerType`. For `FloatType` and `DoubleType`, it fails on empty strings and throws exceptions. Since Spark 3.0, we disallow empty strings and will throw exceptions for data types except for `StringType` and `BinaryType`.
+
+ - Since Spark 3.0, the `from_json` function supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, the behavior of `from_json` did not conform to either `PERMISSIVE` or `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`.
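+
+   A sketch of selecting a mode explicitly (assumes `spark.implicits._` is imported):
+
+   {% highlight scala %}
+   import org.apache.spark.sql.functions.from_json
+   import org.apache.spark.sql.types._
+
+   val schema = new StructType().add("a", IntegerType)
+   val df = Seq("""{"a" 1}""").toDF("json")
+
+   // Default PERMISSIVE mode: the malformed record becomes Row(null) in Spark 3.0
+   df.select(from_json($"json", schema)).show()
+
+   // FAILFAST mode: the malformed record raises an exception instead
+   df.select(from_json($"json", schema, Map("mode" -> "FAILFAST"))).show()
+   {% endhighlight %}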
+
+ - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set.
+
+ - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful.
+
+ - In Spark version 2.4 and earlier, `Dataset.groupByKey` results in a grouped dataset with the key attribute wrongly named as "value" if the key is of non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute "key". The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`.
+
+ - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but -0.0 and 0.0 are considered as different values when used in aggregate grouping keys, window partition keys and join keys. Since Spark 3.0, this bug is fixed. For example, `Seq(-0.0, 0.0).toDF("d").groupBy("d").count()` returns `[(0.0, 2)]` in Spark 3.0, and `[(0.0, 1), (-0.0, 1)]` in Spark 2.4 and earlier.
+
+ - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with the last-wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet); the behavior will be undefined.
+
+ - In Spark version 2.4 and earlier, partition column value is converted as null if it can't be casted to corresponding user provided schema. Since 3.0, partition column value is validated with user provided schema. An exception is thrown if the validation fails. You can disable such validation by setting `spark.sql.sources.validatePartitionColumns` to `false`.
+
+ - In Spark version 2.4 and earlier, the `SET` command works without any warning even if the specified key is a `SparkConf` entry, and it has no effect because the command does not update `SparkConf`; this behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable this check by setting `spark.sql.legacy.setCommandRejectsSparkCoreConfs` to `false`.
+
+ - In Spark version 2.4 and earlier, the CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, the returned row can contain non-`null` fields if some of the CSV column values were parsed and converted to the desired types successfully.
+
+ - In Spark version 2.4 and earlier, the JSON datasource and JSON functions like `from_json` convert a bad JSON record to a row with all `null`s in the PERMISSIVE mode when the specified schema is `StructType`. Since Spark 3.0, the returned row can contain non-`null` fields if some of the JSON column values were parsed and converted to the desired types successfully.
+
+ - Refreshing a cached table triggers a table uncache operation and then a (lazy) table cache operation. In Spark version 2.4 and earlier, the cache name and storage level are not preserved before the uncache operation, so they could change unexpectedly. Since Spark 3.0, the cache name and storage level are preserved first and used for cache recreation. This helps maintain consistent cache behavior upon table refreshing.
+
+ - Since Spark 3.0, the JSON datasource and the JSON function `schema_of_json` infer TimestampType from string values if they match the pattern defined by the JSON option `timestampFormat`. Set the JSON option `inferTimestamp` to `false` to disable such type inference.
+
+ - In Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(Any, DataType)` gets a Scala closure with a primitive-type argument, the returned UDF returns null if the input value is null. Since Spark 3.0, the UDF returns the default value of the Java type if the input value is null. For example, given `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` returns null in Spark 2.4 and earlier if column `x` is null, and 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
+
+ - Since Spark 3.0, the Proleptic Gregorian calendar is used in parsing, formatting, and converting dates and timestamps, as well as in extracting sub-components like years, days, etc. Spark 3.0 uses Java 8 API classes from the java.time packages that are based on ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html). In Spark version 2.4 and earlier, those operations are performed using the hybrid calendar (Julian + Gregorian, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html). The changes impact the results for dates before October 15, 1582 (Gregorian) and affect the following Spark 3.0 APIs:
+
+ - CSV/JSON datasources use the java.time API for parsing and generating CSV/JSON content. In Spark version 2.4 and earlier, java.text.SimpleDateFormat is used for the same purpose, with fallbacks to the parsing mechanisms of Spark 2.0 and 1.x. For example, `2018-12-08 10:39:21.123` with the pattern `yyyy-MM-dd'T'HH:mm:ss.SSS` cannot be parsed since Spark 3.0 because the timestamp does not match the pattern, but it can be parsed by earlier Spark versions due to a fallback to `Timestamp.valueOf`. To parse the same timestamp since Spark 3.0, the pattern should be `yyyy-MM-dd HH:mm:ss.SSS`.
+
+ - The `unix_timestamp`, `date_format`, `to_unix_timestamp`, `from_unixtime`, `to_date`, `to_timestamp` functions. The new implementation supports pattern formats as described at https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html and performs strict checking of its input. For example, the `2015-07-22 10:00:00` timestamp cannot be parsed if the pattern is `yyyy-MM-dd` because the parser does not consume the whole input. Another example is that the `31/01/2015 00:00` input cannot be parsed by the `dd/MM/yyyy hh:mm` pattern because `hh` means hours in the range `1-12`. A short runnable sketch follows this list.
+
+ - The `weekofyear`, `weekday`, `dayofweek`, `date_trunc`, `from_utc_timestamp`, `to_utc_timestamp`, and `unix_timestamp` functions use the java.time API for calculating the week number of the year and the day number of the week, as well as for conversion from/to TimestampType values in the UTC time zone.
+
+ - The JDBC options `lowerBound` and `upperBound` are converted to TimestampType/DateType values in the same way as casting strings to TimestampType/DateType values. The conversion is based on the Proleptic Gregorian calendar and the time zone defined by the SQL config `spark.sql.session.timeZone`. In Spark version 2.4 and earlier, the conversion is based on the hybrid calendar (Julian + Gregorian) and on the default system time zone.
+
+ - Formatting of `TIMESTAMP` and `DATE` literals.
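+
+    A small sketch of the stricter parsing described above (assuming a `spark-shell` session on Spark 3.0):
+
+    {% highlight scala %}
+    import org.apache.spark.sql.functions.to_timestamp
+
+    Seq("2018-12-08 10:39:21.123").toDF("s")
+      .select(to_timestamp($"s", "yyyy-MM-dd HH:mm:ss.SSS").as("ts"))
+      .show(false)
+    // Since Spark 3.0 the pattern must consume the whole input; with the mismatched pattern
+    // "yyyy-MM-dd'T'HH:mm:ss.SSS" there is no longer a fallback to Timestamp.valueOf.
+    {% endhighlight %}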
+
+ - In Spark version 2.4 and earlier, invalid time zone ids are silently ignored and replaced by the GMT time zone, for example, in the `from_utc_timestamp` function. Since Spark 3.0, such time zone ids are rejected, and Spark throws `java.time.DateTimeException`.
+
+ - In Spark version 2.4 and earlier, the `current_timestamp` function returns a timestamp with millisecond resolution only. Since Spark 3.0, the function can return the result with microsecond resolution if the underlying clock available on the system offers such resolution.
+
+ - In Spark version 2.4 and earlier, when reading a Hive SerDe table with Spark native data sources (Parquet/ORC), Spark infers the actual file schema and updates the table schema in the metastore. Since Spark 3.0, Spark doesn't infer the schema anymore. This should not cause any problems for end users, but if it does, set `spark.sql.hive.caseSensitiveInferenceMode` to `INFER_AND_SAVE`.
+
+ - Since Spark 3.0, `TIMESTAMP` literals are converted to strings using the SQL config `spark.sql.session.timeZone`. In Spark version 2.4 and earlier, the conversion uses the default time zone of the Java virtual machine.
+
+ - In Spark version 2.4, when a spark session is created via `cloneSession()`, the newly created spark session inherits its configuration from its parent `SparkContext` even though the same configuration may exist with a different value in its parent spark session. Since Spark 3.0, the configurations of a parent `SparkSession` have a higher precedence over the parent `SparkContext`. The old behavior can be restored by setting `spark.sql.legacy.sessionInitWithConfigDefaults` to `true`.
+
+ - Since Spark 3.0, the Parquet logical type `TIMESTAMP_MICROS` is used by default while saving `TIMESTAMP` columns. In Spark version 2.4 and earlier, `TIMESTAMP` columns are saved as `INT96` in Parquet files. Setting `spark.sql.parquet.outputTimestampType` to `INT96` restores the previous behavior.
+
+ - Since Spark 3.0, if `hive.default.fileformat` is not found in the `Spark SQL configuration`, then it falls back to the hive-site.xml present in the `Hadoop configuration` of `SparkContext`.
+
+ - Since Spark 3.0, Spark casts `String` to `Date/Timestamp` in binary comparisons with dates/timestamps. The previous behaviour of casting `Date/Timestamp` to `String` can be restored by setting `spark.sql.legacy.typeCoercion.datetimeToString` to `true`.
+
+ - Since Spark 3.0, when Avro files are written with a user-provided schema, the fields are matched by field name between the catalyst schema and the Avro schema instead of by position.
+
+ - Since Spark 3.0, when Avro files are written with a user-provided non-nullable schema, even if the catalyst schema is nullable, Spark is still able to write the files. However, Spark throws a runtime NPE if any of the records contains null.
+
+ - Since Spark 3.0, a new protocol is used for fetching shuffle blocks. External shuffle service users need to upgrade the server correspondingly. Otherwise, the error message `UnsupportedOperationException: Unexpected message: FetchShuffleBlocks` occurs. If it is hard to upgrade the shuffle service right now, you can still use the old protocol by setting `spark.shuffle.useOldFetchProtocol` to `true`.
+
+ - Since Spark 3.0, a higher-order function `exists` follows the three-valued boolean logic, i.e., if the `predicate` returns any `null`s and no `true` is obtained, then `exists` will return `null` instead of `false`. For example, `exists(array(1, null, 3), x -> x % 2 == 0)` will be `null`. The previous behaviour can be restored by setting `spark.sql.legacy.arrayExistsFollowsThreeValuedLogic` to `false`.
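+
+    A runnable form of the example above (assuming a `spark-shell` session on Spark 3.0):
+
+    {% highlight scala %}
+    spark.sql("SELECT exists(array(1, null, 3), x -> x % 2 == 0) AS r").show()
+    // Spark 3.0: r = null (no element is even, and a null was encountered)
+    // spark.conf.set("spark.sql.legacy.arrayExistsFollowsThreeValuedLogic", "false") restores the old `false`
+    {% endhighlight %}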
+
+ - Since Spark 3.0, if files or subdirectories disappear during recursive directory listing (i.e. they appear in an intermediate listing but then cannot be read or listed during later phases of the recursive directory listing, due to either concurrent file deletions or object store consistency issues) then the listing will fail with an exception unless `spark.sql.files.ignoreMissingFiles` is `true` (default `false`). In previous versions, these missing files or subdirectories would be ignored. Note that this change of behavior only applies during initial table file listing (or during `REFRESH TABLE`), not during query execution: the net change is that `spark.sql.files.ignoreMissingFiles` is now obeyed during table file listing / query planning, not only at query execution time.
+
+ - Since Spark 3.0, substitution order of nested WITH clauses is changed and an inner CTE definition takes precedence over an outer. In version 2.4 and earlier, `WITH t AS (SELECT 1), t2 AS (WITH t AS (SELECT 2) SELECT * FROM t) SELECT * FROM t2` returns `1` while in version 3.0 it returns `2`. The previous behaviour can be restored by setting `spark.sql.legacy.ctePrecedence.enabled` to `true`.
+
+ - Since Spark 3.0, the `add_months` function does not adjust the resulting date to the last day of the month if the original date is the last day of a month. For example, `select add_months(DATE'2019-02-28', 1)` results in `2019-03-28`. In Spark version 2.4 and earlier, the resulting date is adjusted when the original date is the last day of a month. For example, adding a month to `2019-02-28` results in `2019-03-31`.
+
+ - Since Spark 3.0, a 0-argument Java UDF is executed on the executor side, identically to other UDFs. In Spark version 2.4 and earlier, a 0-argument Java UDF alone was executed on the driver side, and the result was propagated to executors, which might be more performant in some cases but caused inconsistency and correctness issues in some cases.
+
+ - The result of `java.lang.Math`'s `log`, `log1p`, `exp`, `expm1`, and `pow` may vary across platforms. In Spark 3.0, the result of the equivalent SQL functions (including related SQL functions like `LOG10`) return values consistent with `java.lang.StrictMath`. In virtually all cases this makes no difference in the return value, and the difference is very small, but may not exactly match `java.lang.Math` on x86 platforms in cases like, for example, `log(3.0)`, whose value varies between `Math.log()` and `StrictMath.log()`.
+
+ - Since Spark 3.0, a Dataset query fails if it contains an ambiguous column reference caused by a self join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result, which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behavior before Spark 3.0, you can set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`.
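+
+    A minimal sketch of the check (the column name is illustrative; assuming a `spark-shell` session):
+
+    {% highlight scala %}
+    val df1 = spark.range(5).toDF("a")
+    val df2 = df1.filter($"a" > 2)
+    df1.join(df2, df1("a") > df2("a"))
+    // Spark 3.0: fails analysis, since df1("a") and df2("a") resolve to the same attribute
+    // spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", "false") restores the old behavior
+    {% endhighlight %}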
+
+ - Since Spark 3.0, `Cast` function processes string literals such as 'Infinity', '+Infinity', '-Infinity', 'NaN', 'Inf', '+Inf', '-Inf' in case insensitive manner when casting the literals to `Double` or `Float` type to ensure greater compatibility with other database systems. This behaviour change is illustrated in the table below:
+
+| Operation | Result prior to Spark 3.0 | Result starting Spark 3.0 |
+| --------- | ------------------------- | ------------------------- |
+| `CAST('infinity' AS DOUBLE)`<br>`CAST('+infinity' AS DOUBLE)`<br>`CAST('inf' AS DOUBLE)`<br>`CAST('+inf' AS DOUBLE)` | NULL | Double.PositiveInfinity |
+| `CAST('-infinity' AS DOUBLE)`<br>`CAST('-inf' AS DOUBLE)` | NULL | Double.NegativeInfinity |
+| `CAST('infinity' AS FLOAT)`<br>`CAST('+infinity' AS FLOAT)`<br>`CAST('inf' AS FLOAT)`<br>`CAST('+inf' AS FLOAT)` | NULL | Float.PositiveInfinity |
+| `CAST('-infinity' AS FLOAT)`<br>`CAST('-inf' AS FLOAT)` | NULL | Float.NegativeInfinity |
+| `CAST('nan' AS DOUBLE)` | NULL | Double.NaN |
+| `CAST('nan' AS FLOAT)` | NULL | Float.NaN |
+
+ - Since Spark 3.0, special values are supported in conversion from strings to dates and timestamps. Those values are simply notational shorthands that will be converted to ordinary date or timestamp values when read. The following string values are supported for dates:
+ - `epoch [zoneId]` - 1970-01-01
+ - `today [zoneId]` - the current date in the time zone specified by `spark.sql.session.timeZone`
+ - `yesterday [zoneId]` - the current date - 1
+ - `tomorrow [zoneId]` - the current date + 1
+ - `now` - the date of running the current query. It has the same meaning as `today`
+ For example `SELECT date 'tomorrow' - date 'yesterday';` should output `2`. Here are special timestamp values:
+ - `epoch [zoneId]` - 1970-01-01 00:00:00+00 (Unix system time zero)
+ - `today [zoneId]` - midnight today
+ - `yesterday [zoneId]` - midnight yesterday
+ - `tomorrow [zoneId]` - midnight tomorrow
+ - `now` - current query start time
+ For example `SELECT timestamp 'tomorrow';`.
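+
+    For instance, the following sketch (assuming a `spark-shell` session on Spark 3.0) resolves the special strings at query time:
+
+    {% highlight scala %}
+    spark.sql("SELECT date 'epoch' AS d, timestamp 'now' AS ts").show(false)
+    // d = 1970-01-01, ts = the start time of this query in the session time zone
+    {% endhighlight %}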
+
+## Upgrading from Spark SQL 2.4 to 2.4.1
+
+ - The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was
+ inconsistently interpreted as both seconds and milliseconds in Spark 2.4.0 in different parts of the code.
+ Unitless values are now consistently interpreted as milliseconds. Applications that set values like "30"
+ need to specify a value with units like "30s" now, to avoid being interpreted as milliseconds; otherwise,
+ the extremely short interval that results will likely cause applications to fail.
+
+ - When converting a Dataset to another Dataset, Spark will up cast the fields in the original Dataset to the type of the corresponding fields in the target Dataset. In version 2.4 and earlier, this up cast is not very strict, e.g. `Seq("str").toDS.as[Int]` fails, but `Seq("str").toDS.as[Boolean]` works and throws an NPE during execution. In Spark 3.0, the up cast is stricter and turning String into something else is not allowed, i.e. `Seq("str").toDS.as[Boolean]` will fail during analysis.
+
+## Upgrading from Spark SQL 2.3 to 2.4
+
+ - In Spark version 2.3 and earlier, the second parameter to the `array_contains` function is implicitly promoted to the element type of the first, array-type parameter. This type promotion can be lossy and may cause the `array_contains` function to return a wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some changes in behavior, which are illustrated in the table below.
+
+
+| Query | Spark 2.3 or Prior | Spark 2.4 | Remarks |
+| ----- | ------------------ | --------- | ------- |
+| `SELECT array_contains(array(1), 1.34D);` | true | false | In Spark 2.4, left and right parameters are promoted to array of double type and double type, respectively. |
+| `SELECT array_contains(array(1), '1');` | true | AnalysisException is thrown. | An explicit cast can be used in arguments to avoid the exception. In Spark 2.4, AnalysisException is thrown since integer type cannot be promoted to string type in a loss-less manner. |
+| `SELECT array_contains(array(1), 'anystring');` | null | AnalysisException is thrown. | An explicit cast can be used in arguments to avoid the exception. In Spark 2.4, AnalysisException is thrown since integer type cannot be promoted to string type in a loss-less manner. |
+
+ - Since Spark 2.4, when there is a struct field in front of the IN operator before a subquery, the inner query must contain a struct field as well. In previous versions, instead, the fields of the struct were compared to the output of the inner query. E.g., if `a` is a `struct(a string, b int)`, in Spark 2.4 `a in (select (1 as a, 'a' as b) from range(1))` is a valid query, while `a in (select 1, 'a' from range(1))` is not. In previous versions it was the opposite.
+
+ - In versions 2.2.1+ and 2.3, if `spark.sql.caseSensitive` is set to true, then the `CURRENT_DATE` and `CURRENT_TIMESTAMP` functions incorrectly became case-sensitive and would resolve to columns (unless typed in lower case). In Spark 2.4 this has been fixed and the functions are no longer case-sensitive.
+
+ - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations is preserved under a newly added configuration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, Spark evaluates the set operators from left to right as they appear in the query, given that no explicit ordering is enforced by the usage of parentheses.
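+
+    A small sketch contrasting the two orders (assuming a `spark-shell` session):
+
+    {% highlight scala %}
+    spark.sql("""
+      SELECT a FROM VALUES (1), (2) AS t1(a)
+      UNION
+      SELECT a FROM VALUES (2), (3) AS t2(a)
+      INTERSECT
+      SELECT a FROM VALUES (2) AS t3(a)
+    """).show()
+    // Since 2.4 (SQL standard): parsed as t1 UNION (t2 INTERSECT t3)        => rows 1 and 2
+    // With spark.sql.legacy.setopsPrecedence.enabled=true:
+    //   evaluated as (t1 UNION t2) INTERSECT t3                             => row 2 only
+    {% endhighlight %}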
+
+ - Since Spark 2.4, Spark displays the value of the table description column `Last Access` as `UNKNOWN` when the value was Jan 01 1970.
+
+ - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. ORC files created by native ORC writer cannot be read by some old Apache Hive releases. Use `spark.sql.orc.impl=hive` to create the files shared with Hive 2.1.1 and older.
+
+ - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change: for self-describing file formats like Parquet and ORC, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing an empty dataframe.
+
+ - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, a column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``.
+
+ - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema.
+
+ - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promoting both sides to TIMESTAMP. Setting `spark.sql.legacy.compareDateTimestampInTimestamp` to `false` restores the previous behavior. This option will be removed in Spark 3.0.
+
+ - Since Spark 2.4, creating a managed table with a nonempty location is not allowed. An exception is thrown when attempting to create a managed table with a nonempty location. Setting `spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation` to `true` restores the previous behavior. This option will be removed in Spark 3.0.
+
+ - Since Spark 2.4, renaming a managed table to an existing location is not allowed. An exception is thrown when attempting to rename a managed table to an existing location.
+
+ - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, regardless of the order of the input arguments. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception.
+
+ - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time.
+
+ - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files.
+
+ - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. Setting `spark.sql.hive.convertMetastoreOrc` to `false` restores the previous behavior.
+
+ - In version 2.3 and earlier, CSV rows are considered malformed if at least one column value in the row is malformed. The CSV parser drops such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, a CSV row is considered malformed only when it contains malformed column values requested from the CSV datasource; other values can be ignored. As an example, a CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with the single column value 1234, but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`.
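+
+    A small sketch using an in-memory `Dataset[String]` instead of an actual CSV file (assuming a `spark-shell` session):
+
+    {% highlight scala %}
+    val csv = Seq("id,name", "1234").toDS()
+    spark.read.option("header", "true").option("mode", "DROPMALFORMED").csv(csv).select("id").show()
+    // Spark 2.4: one row with id = 1234, because only the requested "id" column is validated
+    // Spark 2.3 and earlier: empty result, because the whole row is treated as malformed
+    {% endhighlight %}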
+
+ - Since Spark 2.4, file listing for computing statistics is done in parallel by default. This can be disabled by setting `spark.sql.statistics.parallelFileListingInStatsComputation.enabled` to `false`.
+
+ - Since Spark 2.4, metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during statistics computation.
+
+ - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and are not written out as any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to the empty (not quoted) string.
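+
+    A short sketch of both behaviors (the output paths are illustrative; assuming a `spark-shell` session):
+
+    {% highlight scala %}
+    val df = Seq(("a", "", 1)).toDF("c1", "c2", "c3")
+    df.write.csv("/tmp/csv_quoted_empty")                          // since 2.4: a,"",1
+    df.write.option("emptyValue", "").csv("/tmp/csv_legacy_empty") // pre-2.4 style: a,,1
+    {% endhighlight %}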
+
+ - Since Spark 2.4, the `LOAD DATA` command supports the wildcards `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special characters like `space` now also work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`.
+
+ - In Spark version 2.3 and earlier, HAVING without GROUP BY is treated as WHERE. This means, `SELECT 1 FROM range(10) HAVING true` is executed as `SELECT 1 FROM range(10) WHERE true` and returns 10 rows. This violates the SQL standard, and has been fixed in Spark 2.4. Since Spark 2.4, HAVING without GROUP BY is treated as a global aggregate, which means `SELECT 1 FROM range(10) HAVING true` will return only one row. To restore the previous behavior, set `spark.sql.legacy.parser.havingWithoutGroupByAsWhere` to `true`.
+
+ - In version 2.3 and earlier, when reading from a Parquet data source table, Spark always returns null for any column whose column names in the Hive metastore schema and the Parquet schema are in different letter cases, no matter whether `spark.sql.caseSensitive` is set to `true` or `false`. Since 2.4, when `spark.sql.caseSensitive` is set to `false`, Spark does case-insensitive column name resolution between the Hive metastore schema and the Parquet schema, so even if column names are in different letter cases, Spark returns the corresponding column values. An exception is thrown if there is ambiguity, i.e. more than one Parquet column is matched. This change also applies to Parquet Hive tables when `spark.sql.hive.convertMetastoreParquet` is set to `true`.
+
+## Upgrading from Spark SQL 2.2 to 2.3
+
+ - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
+
+ - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles.
+
+ - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown.
+
+ - Partition column inference previously found incorrect common type for different inferred types, for example, previously it ended up with double type as the common type for double type and date type. Now it finds the correct common type for such conflicts. The conflict resolution follows the table below:
+
+| InputA \ InputB | NullType | IntegerType | LongType | DecimalType(38,0)* | DoubleType | DateType | TimestampType | StringType |
+| --------------- | -------- | ----------- | -------- | ------------------ | ---------- | -------- | ------------- | ---------- |
+| NullType | NullType | IntegerType | LongType | DecimalType(38,0) | DoubleType | DateType | TimestampType | StringType |
+| IntegerType | IntegerType | IntegerType | LongType | DecimalType(38,0) | DoubleType | StringType | StringType | StringType |
+| LongType | LongType | LongType | LongType | DecimalType(38,0) | StringType | StringType | StringType | StringType |
+| DecimalType(38,0)* | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | StringType | StringType | StringType | StringType |
+| DoubleType | DoubleType | DoubleType | StringType | StringType | DoubleType | StringType | StringType | StringType |
+| DateType | DateType | StringType | StringType | StringType | StringType | DateType | TimestampType | StringType |
+| TimestampType | TimestampType | StringType | StringType | StringType | StringType | TimestampType | TimestampType | StringType |
+| StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType | StringType |
+
+ Note that, for DecimalType(38,0)* , the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type.
+
+ - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](sql-performance-tuning.html#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489).
+
+ - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Prior to Spark 2.3, it always returned a string regardless of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.
+
+ - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Prior to Spark 2.3, it always returned a string regardless of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`.
+
+ - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant with SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes
+
+ - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, i.e. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive modulus (`pmod`).
+
+ - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them.
+
+ - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses the previous rules, i.e. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible (see the sketch below).
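+
+    A small sketch contrasting the two rule sets (assuming a `spark-shell` session):
+
+    {% highlight scala %}
+    spark.sql("SELECT CAST(1 AS DECIMAL(38,18)) / CAST(3 AS DECIMAL(38,18)) AS q").show()
+    // Default (allowPrecisionLoss = true): the scale is reduced and a rounded value such as 0.333333 is returned
+    // spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
+    // With the legacy rules, NULL is returned whenever the exact result cannot be represented
+    {% endhighlight %}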
+
+ - The semantics of un-aliased subqueries have not been well defined and had confusing behaviors. Since Spark 2.3, such confusing cases are invalidated, for example: `SELECT v.i from (SELECT i FROM v)`; Spark throws an analysis exception in this case because users should not be able to use the qualifier inside a subquery. See [SPARK-20690](https://issues.apache.org/jira/browse/SPARK-20690) and [SPARK-21335](https://issues.apache.org/jira/browse/SPARK-21335) for more details.
+
+ - When creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 2.3, the builder no longer updates the configurations. If you want to update them, you need to do so prior to creating a `SparkSession`.
+
+## Upgrading from Spark SQL 2.1 to 2.2
+
+ - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
+
+ - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have columns that exist in both the partition schema and the data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In the 2.2.0 and 2.1.x releases, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty).
+
+ - Since Spark 2.2, view definitions are stored in a different way from prior versions. This may leave Spark unable to read views created by prior versions. In such cases, you need to recreate the views using `ALTER VIEW AS` or `CREATE OR REPLACE VIEW AS` with newer Spark versions.
+
+## Upgrading from Spark SQL 2.0 to 2.1
+
+ - Datasource tables now store partition metadata in the Hive metastore. This means that Hive DDLs such as `ALTER TABLE PARTITION ... SET LOCATION` are now available for tables created with the Datasource API.
+
+ - Legacy datasource tables can be migrated to this format via the `MSCK REPAIR TABLE` command. Migrating legacy tables is recommended to take advantage of Hive DDL support and improved planning performance.
+
+ - To determine if a table has been migrated, look for the `PartitionProvider: Catalog` attribute when issuing `DESCRIBE FORMATTED` on the table.
+ - Changes to `INSERT OVERWRITE TABLE ... PARTITION ...` behavior for Datasource tables.
+
+ - In prior Spark versions `INSERT OVERWRITE` overwrote the entire Datasource table, even when given a partition specification. Now only partitions matching the specification are overwritten.
+
+ - Note that this still differs from the behavior of Hive tables, which is to overwrite only partitions overlapping with newly inserted data.
+
+## Upgrading from Spark SQL 1.6 to 2.0
+
+ - `SparkSession` is now the new entry point of Spark that replaces the old `SQLContext` and
+   `HiveContext`. Note that the old SQLContext and HiveContext are kept for backward compatibility.
+   A new `catalog` interface is accessible from `SparkSession` - existing APIs for accessing databases
+   and tables, such as `listTables`, `createExternalTable`, `dropTempView`, and `cacheTable`, are moved here.
+
+ - Dataset API and DataFrame API are unified. In Scala, `DataFrame` becomes a type alias for
+ `Dataset[Row]`, while Java API users must replace `DataFrame` with `Dataset`. Both the typed
+ transformations (e.g., `map`, `filter`, and `groupByKey`) and untyped transformations (e.g.,
+ `select` and `groupBy`) are available on the Dataset class. Since compile-time type-safety in
+ Python and R is not a language feature, the concept of Dataset does not apply to these languages’
+ APIs. Instead, `DataFrame` remains the primary programming abstraction, which is analogous to the
+ single-node data frame notion in these languages.
+
+ - Dataset and DataFrame API `unionAll` has been deprecated and replaced by `union`
+
+ - Dataset and DataFrame API `explode` has been deprecated, alternatively, use `functions.explode()` with `select` or `flatMap`
+
+ - Dataset and DataFrame API `registerTempTable` has been deprecated and replaced by `createOrReplaceTempView`
+
+ - Changes to `CREATE TABLE ... LOCATION` behavior for Hive tables.
+
+ - From Spark 2.0, `CREATE TABLE ... LOCATION` is equivalent to `CREATE EXTERNAL TABLE ... LOCATION`
+ in order to prevent accidentally dropping the existing data in the user-provided locations.
+ That means, a Hive table created in Spark SQL with the user-specified location is always a Hive external table.
+ Dropping external tables will not remove the data. Users are not allowed to specify the location for Hive managed tables.
+ Note that this is different from the Hive behavior.
+
+ - As a result, `DROP TABLE` statements on those tables will not remove the data.
+
+ - `spark.sql.parquet.cacheMetadata` is no longer used.
+ See [SPARK-13664](https://issues.apache.org/jira/browse/SPARK-13664) for details.
+
+## Upgrading from Spark SQL 1.5 to 1.6
+
+ - From Spark 1.6, by default, the Thrift server runs in multi-session mode, which means each JDBC/ODBC
+ connection owns a copy of its own SQL configuration and temporary function registry. Cached
+ tables are still shared though. If you prefer to run the Thrift server in the old single-session
+ mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add
+ this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`:
+
+ {% highlight bash %}
+ ./sbin/start-thriftserver.sh \
+ --conf spark.sql.hive.thriftServer.singleSession=true \
+ ...
+ {% endhighlight %}
+
+ - From Spark 1.6, LongType casts to TimestampType expect seconds instead of microseconds. This
+ change was made to match the behavior of Hive 1.2 for more consistent type casting to TimestampType
+ from numeric types. See [SPARK-11724](https://issues.apache.org/jira/browse/SPARK-11724) for
+ details.
+
+## Upgrading from Spark SQL 1.4 to 1.5
+
+ - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with
+ code generation for expression evaluation. These features can both be disabled by setting
+ `spark.sql.tungsten.enabled` to `false`.
+
+ - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting
+ `spark.sql.parquet.mergeSchema` to `true`.
+
+ - In-memory columnar storage partition pruning is on by default. It can be disabled by setting
+ `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`.
+
+ - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum
+ precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now
+ used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`.
+
+ - Timestamps are now stored at a precision of 1us, rather than 1ns
+
+ - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains
+ unchanged.
+
+ - The canonical names of SQL/DataFrame functions are now lower case (e.g., sum vs SUM).
+
+ - JSON data source will not automatically load new files that are created by other applications
+ (i.e. files that are not inserted to the dataset through Spark SQL).
+ For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore),
+ users can use `REFRESH TABLE` SQL command or `HiveContext`'s `refreshTable` method
+ to include those new files to the table. For a DataFrame representing a JSON dataset, users need to recreate
+ the DataFrame and the new DataFrame will include new files.
+
+## Upgrading from Spark SQL 1.3 to 1.4
+
+#### DataFrame data reader/writer interface
+{:.no_toc}
+
+Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`)
+and writing data out (`DataFrame.write`),
+and deprecated the old APIs (e.g., `SQLContext.parquetFile`, `SQLContext.jsonFile`).
+
+See the API docs for `SQLContext.read` and `DataFrame.write` for more information.
+
+
+#### DataFrame.groupBy retains grouping columns
+{:.no_toc}
+
+Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the
+grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
+
+
+
+{% highlight scala %}
+
+// In 1.3.x, in order for the grouping column "department" to show up,
+// it must be included explicitly as part of the agg function call.
+df.groupBy("department").agg($"department", max("age"), sum("expense"))
+
+// In 1.4+, grouping column "department" is included automatically.
+df.groupBy("department").agg(max("age"), sum("expense"))
+
+// Revert to 1.3 behavior (not retaining grouping column) by:
+sqlContext.setConf("spark.sql.retainGroupColumns", "false")
+
+{% endhighlight %}
+
+
+
+{% highlight java %}
+
+// In 1.3.x, in order for the grouping column "department" to show up,
+// it must be included explicitly as part of the agg function call.
+df.groupBy("department").agg(col("department"), max("age"), sum("expense"));
+
+// In 1.4+, grouping column "department" is included automatically.
+df.groupBy("department").agg(max("age"), sum("expense"));
+
+// Revert to 1.3 behavior (not retaining grouping column) by:
+sqlContext.setConf("spark.sql.retainGroupColumns", "false");
+
+{% endhighlight %}
+
+
+
+{% highlight python %}
+
+import pyspark.sql.functions as func
+
+# In 1.3.x, in order for the grouping column "department" to show up,
+# it must be included explicitly as part of the agg function call.
+df.groupBy("department").agg(df["department"], func.max("age"), func.sum("expense"))
+
+# In 1.4+, grouping column "department" is included automatically.
+df.groupBy("department").agg(func.max("age"), func.sum("expense"))
+
+# Revert to 1.3.x behavior (not retaining grouping column) by:
+sqlContext.setConf("spark.sql.retainGroupColumns", "false")
+
+{% endhighlight %}
+
+
+
+
+
+#### Behavior change on DataFrame.withColumn
+{:.no_toc}
+
+Prior to 1.4, DataFrame.withColumn() supported adding a column only. The column was always added
+as a new column with its specified name in the result DataFrame, even if an existing column had
+the same name. Since 1.4, DataFrame.withColumn() supports adding a column with a name different
+from those of all existing columns, or replacing an existing column of the same name.
+
+Note that this change applies only to the Scala API, not to PySpark and SparkR.
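+
+A minimal sketch of the 1.4+ behavior (assuming a `spark-shell` session):
+
+{% highlight scala %}
+import org.apache.spark.sql.functions.lit
+
+val df = Seq((1, "a")).toDF("id", "name")
+df.withColumn("name", lit("b")).show() // replaces the existing "name" column instead of adding a duplicate
+df.withColumn("age", lit(30)).show()   // adds a new column, as before
+{% endhighlight %}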
+
+
+## Upgrading from Spark SQL 1.0-1.2 to 1.3
+
+In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the
+available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other
+releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked
+as unstable (i.e., DeveloperAPI or Experimental).
+
+#### Rename of SchemaRDD to DataFrame
+{:.no_toc}
+
+The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has
+been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD
+directly, but instead provide most of the functionality that RDDs provide through their own
+implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method.
+
+In Scala, there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for
+some use cases. It is still recommended that users update their code to use `DataFrame` instead.
+Java and Python users will need to update their code.
+
+#### Unification of the Java and Scala APIs
+{:.no_toc}
+
+Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`)
+that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users
+of either language should use `SQLContext` and `DataFrame`. In general these classes try to
+use types that are usable from both languages (i.e. `Array` instead of language-specific collections).
+In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading
+is used instead.
+
+Additionally, the Java specific types API has been removed. Users of both Scala and Java should
+use the classes present in `org.apache.spark.sql.types` to describe schema programmatically.
+
+
+#### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only)
+{:.no_toc}
+
+Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought
+all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit
+conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`.
+Users should now write `import sqlContext.implicits._`.
+
+Additionally, the implicit conversions now only augment RDDs that are composed of `Product`s (i.e.,
+case classes or tuples) with a method `toDF`, instead of applying automatically.
+
+When using functions inside of the DSL (now replaced with the `DataFrame` API) users used to import
+`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used:
+`import org.apache.spark.sql.functions._`.
+
+#### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only)
+{:.no_toc}
+
+Spark 1.3 removes the type aliases that were present in the base sql package for `DataType`. Users
+should instead import the classes in `org.apache.spark.sql.types`
+
+#### UDF Registration Moved to `sqlContext.udf` (Java & Scala)
+{:.no_toc}
+
+Functions that are used to register UDFs, either for use in the DataFrame DSL or SQL, have been
+moved into the udf object in `SQLContext`.
+
+
+
+{% highlight scala %}
+
+sqlContext.udf.register("strLen", (s: String) => s.length())
+
+{% endhighlight %}
+
+
+
+{% highlight java %}
+
+sqlContext.udf().register("strLen", (String s) -> s.length(), DataTypes.IntegerType);
+
+{% endhighlight %}
+
+
+
+
+Python UDF registration is unchanged.
+
+
+
+## Compatibility with Apache Hive
+
+Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs.
+Currently, Hive SerDes and UDFs are based on Hive 1.2.1,
+and Spark SQL can be connected to different versions of Hive Metastore
+(from 0.12.0 to 2.3.6 and 3.0.0 to 3.1.2. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)).
+
+#### Deploying in Existing Hive Warehouses
+{:.no_toc}
+
+The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive
+installations. You do not need to modify your existing Hive Metastore or change the data placement
+or partitioning of your tables.
+
+### Supported Hive Features
+{:.no_toc}
+
+Spark SQL supports the vast majority of Hive features, such as:
+
+* Hive query statements, including:
+ * `SELECT`
+ * `GROUP BY`
+ * `ORDER BY`
+ * `CLUSTER BY`
+ * `SORT BY`
+* All Hive operators, including:
+ * Relational operators (`=`, `<=>`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc)
+ * Arithmetic operators (`+`, `-`, `*`, `/`, `%`, etc)
+ * Logical operators (`AND`, `&&`, `OR`, `||`, etc)
+ * Complex type constructors
+ * Mathematical functions (`sign`, `ln`, `cos`, etc)
+ * String functions (`instr`, `length`, `printf`, etc)
+* User defined functions (UDF)
+* User defined aggregation functions (UDAF)
+* User defined serialization formats (SerDes)
+* Window functions
+* Joins
+ * `JOIN`
+ * `{LEFT|RIGHT|FULL} OUTER JOIN`
+ * `LEFT SEMI JOIN`
+ * `CROSS JOIN`
+* Unions
+* Sub-queries
+ * `SELECT col FROM ( SELECT a + b AS col from t1) t2`
+* Sampling
+* Explain
+* Partitioned tables including dynamic partition insertion
+* View
+ * If column aliases are not specified in view definition queries, both Spark and Hive will
+ generate alias names, but in different ways. In order for Spark to be able to read views created
+ by Hive, users should explicitly specify column aliases in view definition queries. As an
+ example, Spark cannot read `v1` created as below by Hive.
+
+ ```
+ CREATE VIEW v1 AS SELECT * FROM (SELECT c + 1 FROM (SELECT 1 c) t1) t2;
+ ```
+
+ Instead, you should create `v1` as below with column aliases explicitly specified.
+
+ ```
+ CREATE VIEW v1 AS SELECT * FROM (SELECT c + 1 AS inc_c FROM (SELECT 1 c) t1) t2;
+ ```
+
+* All Hive DDL Functions, including:
+ * `CREATE TABLE`
+ * `CREATE TABLE AS SELECT`
+ * `ALTER TABLE`
+* Most Hive Data types, including:
+ * `TINYINT`
+ * `SMALLINT`
+ * `INT`
+ * `BIGINT`
+ * `BOOLEAN`
+ * `FLOAT`
+ * `DOUBLE`
+ * `STRING`
+ * `BINARY`
+ * `TIMESTAMP`
+ * `DATE`
+ * `ARRAY<>`
+ * `MAP<>`
+ * `STRUCT<>`
+
+### Unsupported Hive Functionality
+{:.no_toc}
+
+Below is a list of Hive features that we don't support yet. Most of these features are rarely used
+in Hive deployments.
+
+**Major Hive Features**
+
+* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL
+ doesn't support buckets yet.
+
+
+**Esoteric Hive Features**
+
+* `UNION` type
+* Unique join
+* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at
+ the moment and only supports populating the sizeInBytes field of the hive metastore.
+
+**Hive Input/Output Formats**
+
+* File format for CLI: For results showing back to the CLI, Spark SQL only supports TextOutputFormat.
+* Hadoop archive
+
+**Hive Optimizations**
+
+A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are
+less important due to Spark SQL's in-memory computational model. Others are slotted for future
+releases of Spark SQL.
+
+* Block-level bitmap indexes and virtual columns (used to build indexes)
+* Automatically determine the number of reducers for joins and groupbys: Currently, in Spark SQL, you
+ need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`".
+* Meta-data only query: For queries that can be answered by using only metadata, Spark SQL still
+ launches tasks to compute the result.
+* Skew data flag: Spark SQL does not follow the skew data flags in Hive.
+* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint.
+* Merge multiple small files for query results: if the result output contains multiple small files,
+ Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS
+ metadata. Spark SQL does not support that.
+
+**Hive UDF/UDTF/UDAF**
+
+Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are the unsupported APIs:
+
+* `getRequiredJars` and `getRequiredFiles` (`UDF` and `GenericUDF`) are functions to automatically
+ include additional resources required by this UDF.
+* `initialize(StructObjectInspector)` in `GenericUDTF` is not supported yet. Spark SQL currently uses
+ a deprecated interface `initialize(ObjectInspector[])` only.
+* `configure` (`GenericUDF`, `GenericUDTF`, and `GenericUDAFEvaluator`) is a function to initialize
+ functions with `MapredContext`, which is inapplicable to Spark.
+* `close` (`GenericUDF` and `GenericUDAFEvaluator`) is a function to release associated resources.
+ Spark SQL does not call this function when tasks finish.
+* `reset` (`GenericUDAFEvaluator`) is a function to re-initialize aggregation for reusing the same aggregation.
+ Spark SQL currently does not support the reuse of aggregation.
+* `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating
+ an aggregate over a fixed window.
+
+### Incompatible Hive UDF
+{:.no_toc}
+
+Below are the scenarios in which Hive and Spark generate different results:
+
+* `SQRT(n)` If n < 0, Hive returns null, Spark SQL returns NaN.
+* `ACOS(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN.
+* `ASIN(n)` If n < -1 or n > 1, Hive returns null, Spark SQL returns NaN.
diff --git a/docs/mllib-migration-guides.md b/docs/sql-migration-old.md
similarity index 73%
rename from docs/mllib-migration-guides.md
rename to docs/sql-migration-old.md
index b746b96e19f07..e100820f6d664 100644
--- a/docs/mllib-migration-guides.md
+++ b/docs/sql-migration-old.md
@@ -1,7 +1,7 @@
---
layout: global
-title: Old Migration Guides - MLlib
-displayTitle: Old Migration Guides - MLlib
+title: Migration Guide
+displayTitle: Migration Guide
license: |
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
@@ -19,6 +19,5 @@ license: |
limitations under the License.
---
-The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide).
+The migration guide is now archived [on this page](sql-migration-guide.html).
-Past migration guides are now stored at [ml-migration-guides.html](ml-migration-guides.html).
diff --git a/docs/sql-ref-null-semantics.md b/docs/sql-ref-null-semantics.md
new file mode 100644
index 0000000000000..a67b3993a31c0
--- /dev/null
+++ b/docs/sql-ref-null-semantics.md
@@ -0,0 +1,703 @@
+---
+layout: global
+title: NULL Semantics
+displayTitle: NULL Semantics
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+### Description
+A table consists of a set of rows and each row contains a set of columns.
+A column is associated with a data type and represents
+a specific attribute of an entity (for example, `age` is a column of an
+entity called `person`). Sometimes, the value of a column
+specific to a row is not known at the time the row comes into existence.
+In `SQL`, such values are represented as `NULL`. This section details the
+semantics of `NULL` values handling in various operators, expressions and
+other `SQL` constructs.
+
+1. [Null handling in comparison operators](#comp-operators)
+2. [Null handling in Logical operators](#logical-operators)
+3. [Null handling in Expressions](#expressions)
+ 1. [Null handling in null-intolerant expressions](#null-in-tolerant)
+ 2. [Null handling in expressions that can process null value operands](#can-process-null)
+ 3. [Null handling in built-in aggregate expressions](#built-in-aggregate)
+4. [Null handling in WHERE, HAVING and JOIN conditions](#condition-expressions)
+5. [Null handling in GROUP BY and DISTINCT](#aggregate-operator)
+6. [Null handling in ORDER BY](#order-by)
+7. [Null handling in UNION, INTERSECT, EXCEPT](#set-operators)
+8. [Null handling in EXISTS and NOT EXISTS subquery](#exists-not-exists)
+9. [Null handling in IN and NOT IN subquery](#in-not-in)
+
+
+
+The following illustrates the schema layout and data of a table named `person`. The data contains `NULL` values in
+the `age` column and this table will be used in various examples in the sections below.
+**TABLE: person**
+
+| Id  | Name     | Age  |
+|-----|----------|------|
+| 100 | Joe      | 30   |
+| 200 | Marry    | NULL |
+| 300 | Mike     | 18   |
+| 400 | Fred     | 50   |
+| 500 | Albert   | NULL |
+| 600 | Michelle | 30   |
+| 700 | Dan      | 50   |
+
+
+### Comparison operators
+
+Apache Spark supports the standard comparison operators such as '>', '>=', '=', '<' and '<='.
+The result of these operators is unknown or `NULL` when one of the operands or both the operands are
+unknown or `NULL`. In order to compare the `NULL` values for equality, Spark provides a null-safe
+equal operator (`<=>`), which returns `False` when one of the operands is `NULL` and returns `True` when
+both the operands are `NULL`. The following table illustrates the behaviour of comparison operators when
+one or both operands are `NULL`:
+
+
+| Left Operand | Right Operand | `>` | `>=` | `=` | `<` | `<=` | `<=>` |
+|--------------|---------------|-----|------|-----|-----|------|-------|
+| NULL | Any value | NULL | NULL | NULL | NULL | NULL | False |
+| Any value | NULL | NULL | NULL | NULL | NULL | NULL | False |
+| NULL | NULL | NULL | NULL | NULL | NULL | NULL | True |
+
+### Examples
+{% highlight sql %}
+-- Normal comparison operators return `NULL` when one of the operands is `NULL`.
+SELECT 5 > null AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |null |
+ +-----------------+
+
+-- Normal comparison operators return `NULL` when both the operands are `NULL`.
+SELECT null = null AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |null |
+ +-----------------+
+
+-- The null-safe equal operator returns `False` when one of the operands is `NULL`.
+SELECT 5 <=> null AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |false |
+ +-----------------+
+
+-- The null-safe equal operator returns `True` when both the operands are `NULL`.
+SELECT NULL <=> NULL;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |true |
+ +-----------------+
+{% endhighlight %}
+
+### Logical operators
+Spark supports standard logical operators such as `AND`, `OR` and `NOT`. These operators take `Boolean` expressions
+as the arguments and return a `Boolean` value.
+
+The following tables illustrate the behavior of logical operators when one or both operands are `NULL`.
+
+
+|Left Operand|Right Operand|OR|AND|
+|------------|-------------|--|---|
+|True|NULL|True|NULL|
+|False|NULL|NULL|False|
+|NULL|True|True|NULL|
+|NULL|False|NULL|False|
+|NULL|NULL|NULL|NULL|
+
+|Operand|NOT|
+|-------|---|
+|NULL|NULL|
+
+#### Examples
+{% highlight sql %}
+-- The `OR` operator returns `True` when one of the operands is `True`, even if the other is `NULL`.
+SELECT (true OR null) AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |true |
+ +-----------------+
+
+-- The `OR` operator returns `NULL` when one of the operands is `False` and the other is `NULL`.
+SELECT (null OR false) AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |null |
+ +-----------------+
+
+-- The `NOT` operator returns `NULL` when its operand is `NULL`.
+SELECT NOT(null) AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |null |
+ +-----------------+
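+
+-- Additional illustrative example, consistent with the truth table above: the `AND` operator
+-- returns `False` when one of the operands is `False`, even if the other is `NULL`.
+SELECT (null AND false) AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |false            |
+ +-----------------+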
+{% endhighlight %}
+
+### Expressions
+The comparison operators and logical operators are treated as expressions in
+Spark. Other than these two kinds of expressions, Spark supports other forms of
+expressions such as function expressions, cast expressions, etc. The expressions
+in Spark can be broadly classified as:
+- Null in-tolerant expressions
+- Expressions that can process `NULL` value operands
+ - The result of these expressions depends on the expression itself.
+
+#### Null in-tolerant expressions
+Null in-tolerant expressions return `NULL` when one or more arguments of
+the expression are `NULL`; most of the expressions fall into this category.
+
+##### Examples
+{% highlight sql %}
+SELECT concat('John', null) as expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |null |
+ +-----------------+
+
+SELECT positive(null) as expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |null |
+ +-----------------+
+
+SELECT to_date(null) as expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |null |
+ +-----------------+
+{% endhighlight %}
+
+#### Expressions that can process null value operands
+
+This class of expressions is designed to handle `NULL` values. The result of the
+expressions depends on the expression itself. As an example, the function expression `isnull`
+returns `true` on a null input and `false` on a non-null input, whereas the function `coalesce`
+returns the first non-`NULL` value in its list of operands. However, `coalesce` returns
+`NULL` when all its operands are `NULL`. Below is an incomplete list of expressions of this category.
+ - COALESCE
+ - NULLIF
+ - IFNULL
+ - NVL
+ - NVL2
+ - ISNAN
+ - NANVL
+ - ISNULL
+ - ISNOTNULL
+ - ATLEASTNNONNULLS
+ - IN
+
+
+##### Examples
+{% highlight sql %}
+SELECT isnull(null) AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |true |
+ +-----------------+
+
+-- Returns the first occurrence of a non-`NULL` value.
+SELECT coalesce(null, null, 3, null) AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |3 |
+ +-----------------+
+
+-- Returns `NULL` as all its operands are `NULL`.
+SELECT coalesce(null, null, null, null) AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |null |
+ +-----------------+
+
+SELECT isnan(null) as expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |false |
+ +-----------------+
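+
+-- Additional illustrative examples for a few more expressions from the list above.
+-- `nullif` returns `NULL` when both operands are equal, otherwise it returns the first operand.
+SELECT nullif(2, 2) AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |null             |
+ +-----------------+
+
+-- `nvl` returns the second operand when the first operand is `NULL`.
+SELECT nvl(null, 5) AS expression_output;
+ +-----------------+
+ |expression_output|
+ +-----------------+
+ |5                |
+ +-----------------+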
+{% endhighlight %}
+
+#### Built-in Aggregate Expressions
+Aggregate functions compute a single result by processing a set of input rows. Below are
+the rules of how `NULL` values are handled by aggregate functions.
+- `NULL` values are ignored during processing by all the aggregate functions.
+ - The only exception to this rule is the `COUNT(*)` function.
+- Some aggregate functions return `NULL` when all input values are `NULL` or the input data set
+ is empty. The list of these functions is:
+ - MAX
+ - MIN
+ - SUM
+ - AVG
+ - EVERY
+ - ANY
+ - SOME
+
+##### Examples
+{% highlight sql %}
+-- `count(*)` does not skip `NULL` values.
+SELECT count(*) FROM person;
+ +--------+
+ |count(1)|
+ +--------+
+ |7 |
+ +--------+
+
+-- `NULL` values in column `age` are skipped from processing.
+SELECT count(age) FROM person;
+ +----------+
+ |count(age)|
+ +----------+
+ |5 |
+ +----------+
+
+-- `count(*)` on an empty input set returns 0. This is unlike the other
+-- aggregate functions, such as `max`, which return `NULL`.
+SELECT count(*) FROM person where 1 = 0;
+ +--------+
+ |count(1)|
+ +--------+
+ |0 |
+ +--------+
+
+-- `NULL` values are excluded from computation of maximum value.
+SELECT max(age) FROM person;
+ +--------+
+ |max(age)|
+ +--------+
+ |50 |
+ +--------+
+
+-- `max` returns `NULL` on an empty input set.
+SELECT max(age) FROM person where 1 = 0;
+ +--------+
+ |max(age)|
+ +--------+
+ |null |
+ +--------+
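+
+-- Illustrative: aggregate functions other than `count` return `NULL` when every input value is `NULL`.
+SELECT sum(age) FROM person WHERE age IS NULL;
+ +--------+
+ |sum(age)|
+ +--------+
+ |null    |
+ +--------+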
+
+{% endhighlight %}
+
+### Condition expressions in WHERE, HAVING and JOIN clauses
+The `WHERE` and `HAVING` operators filter rows based on the user-specified condition.
+A `JOIN` operator is used to combine rows from two tables based on a join condition.
+For all three operators, a condition expression is a boolean expression and can return
+`True`, `False` or `Unknown (NULL)`. A condition is "satisfied" if the result of the condition is `True`.
+
+#### Examples
+{% highlight sql %}
+-- Persons whose age is unknown (`NULL`) are filtered out from the result set.
+SELECT * FROM person WHERE age > 0;
+ +--------+---+
+ |name |age|
+ +--------+---+
+ |Michelle|30 |
+ |Fred |50 |
+ |Mike |18 |
+ |Dan |50 |
+ |Joe |30 |
+ +--------+---+
+
+-- `IS NULL` expression is used in disjunction to select the persons
+-- with unknown (`NULL`) records.
+SELECT * FROM person WHERE age > 0 OR age IS NULL;
+ +--------+----+
+ |name |age |
+ +--------+----+
+ |Albert |null|
+ |Michelle|30 |
+ |Fred |50 |
+ |Mike |18 |
+ |Dan |50 |
+ |Marry |null|
+ |Joe |30 |
+ +--------+----+
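+
+-- Illustrative: `IS NOT NULL` keeps only the rows whose age is known
+-- (the five rows with a non-`NULL` age, same as the `age > 0` filter above for this data).
+SELECT name, age FROM person WHERE age IS NOT NULL;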
+
+-- Persons with unknown (`NULL`) ages are skipped from processing.
+SELECT age, count(*) FROM person GROUP BY age HAVING max(age) > 18;
+ +---+--------+
+ |age|count(1)|
+ +---+--------+
+ |50 |2 |
+ |30 |2 |
+ +---+--------+
+
+-- A self join case with a join condition `p1.age = p2.age AND p1.name = p2.name`.
+-- The persons with unknown age (`NULL`) are filtered out by the join operator.
+SELECT * FROM person p1, person p2
+WHERE p1.age = p2.age
+ AND p1.name = p2.name;
+ +--------+---+--------+---+
+ |name |age|name |age|
+ +--------+---+--------+---+
+ |Michelle|30 |Michelle|30 |
+ |Fred |50 |Fred |50 |
+ |Mike |18 |Mike |18 |
+ |Dan |50 |Dan |50 |
+ |Joe |30 |Joe |30 |
+ +--------+---+--------+---+
+
+-- The age columns from both legs of the join are compared using the null-safe equal operator, which
+-- is why the persons with unknown age (`NULL`) are qualified by the join.
+SELECT * FROM person p1, person p2
+WHERE p1.age <=> p2.age
+ AND p1.name = p2.name;
++--------+----+--------+----+
+| name| age| name| age|
++--------+----+--------+----+
+| Albert|null| Albert|null|
+|Michelle| 30|Michelle| 30|
+| Fred| 50| Fred| 50|
+| Mike| 18| Mike| 18|
+| Dan| 50| Dan| 50|
+| Marry|null| Marry|null|
+| Joe| 30| Joe| 30|
++--------+----+--------+----+
+
+{% endhighlight %}
+
+### Aggregate operator (GROUP BY, DISTINCT)
+As discussed in the previous section on [comparison operators](sql-ref-null-semantics.html#comparison-operators),
+two `NULL` values are not equal. However, for the purpose of grouping and distinct processing, two or more
+`NULL` values are grouped together into the same bucket. This behavior conforms to the SQL
+standard and to other enterprise database management systems.
+
+#### Examples
+{% highlight sql %}
+-- `NULL` values are put in one bucket in `GROUP BY` processing.
+SELECT age, count(*) FROM person GROUP BY age;
+ +----+--------+
+ |age |count(1)|
+ +----+--------+
+ |null|2 |
+ |50 |2 |
+ |30 |2 |
+ |18 |1 |
+ +----+--------+
+
+-- All `NULL` ages are considered one distinct value in `DISTINCT` processing.
+SELECT DISTINCT age FROM person;
+ +----+
+ |age |
+ +----+
+ |null|
+ |50 |
+ |30 |
+ |18 |
+ +----+
+
+{% endhighlight %}
+
+### Sort operator (ORDER BY Clause)
+Spark SQL supports null ordering specification in the `ORDER BY` clause. Spark processes the `ORDER BY` clause by
+placing all the `NULL` values first or last depending on the null ordering specification. By default, all
+the `NULL` values are placed first when sorting in ascending order.
+
+#### Examples
+{% highlight sql %}
+-- `NULL` values are shown first and the other values
+-- are sorted in ascending order.
+SELECT age, name FROM person ORDER BY age;
+ +----+--------+
+ |age |name |
+ +----+--------+
+ |null|Marry |
+ |null|Albert |
+ |18 |Mike |
+ |30 |Michelle|
+ |30 |Joe |
+ |50 |Fred |
+ |50 |Dan |
+ +----+--------+
+
+-- Column values other than `NULL` are sorted in ascending
+-- order and `NULL` values are shown last.
+SELECT age, name FROM person ORDER BY age NULLS LAST;
+ +----+--------+
+ |age |name |
+ +----+--------+
+ |18 |Mike |
+ |30 |Michelle|
+ |30 |Joe |
+ |50 |Dan |
+ |50 |Fred |
+ |null|Marry |
+ |null|Albert |
+ +----+--------+
+
+-- Column values other than `NULL` are sorted in descending
+-- order and `NULL` values are shown last.
+SELECT age, name FROM person ORDER BY age DESC NULLS LAST;
+ +----+--------+
+ |age |name |
+ +----+--------+
+ |50 |Fred |
+ |50 |Dan |
+ |30 |Michelle|
+ |30 |Joe |
+ |18 |Mike |
+ |null|Marry |
+ |null|Albert |
+ +----+--------+
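+
+-- Illustrative: `NULLS FIRST` can also be combined with a descending sort;
+-- the `NULL` ages appear first, followed by the non-`NULL` ages in descending order.
+SELECT age, name FROM person ORDER BY age DESC NULLS FIRST;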
+{% endhighlight %}
+
+### Set operators (UNION, INTERSECT, EXCEPT)
+`NULL` values are compared in a null-safe manner for equality in the context of
+set operations. That means when comparing rows, two `NULL` values are considered
+equal unlike the regular `EqualTo`(`=`) operator.
+
+#### Examples
+{% highlight sql %}
+CREATE VIEW unknown_age AS SELECT * FROM person WHERE age IS NULL;
+
+-- Only common rows between two legs of `INTERSECT` are in the
+-- result set. The comparison between columns of the row are done
+-- in a null-safe manner.
+SELECT name, age FROM person
+INTERSECT
+SELECT name, age from unknown_age;
+ +------+----+
+ |name |age |
+ +------+----+
+ |Albert|null|
+ |Marry |null|
+ +------+----+
+
+-- `NULL` values from two legs of the `EXCEPT` are not in output.
+-- This basically shows that the comparison happens in a null-safe manner.
+SELECT age, name FROM person
+EXCEPT
+SELECT age, name FROM unknown_age;
+ +---+--------+
+ |age|name |
+ +---+--------+
+ |30 |Joe |
+ |50 |Fred |
+ |30 |Michelle|
+ |18 |Mike |
+ |50 |Dan |
+ +---+--------+
+
+-- Performs `UNION` operation between two sets of data.
+-- The comparison between columns of the row are done in
+-- a null-safe manner.
+SELECT name, age FROM person
+UNION
+SELECT name, age FROM unknown_age;
+ +--------+----+
+ |name |age |
+ +--------+----+
+ |Albert |null|
+ |Joe |30 |
+ |Michelle|30 |
+ |Marry |null|
+ |Fred |50 |
+ |Mike |18 |
+ |Dan |50 |
+ +--------+----+
+{% endhighlight %}
+
+
+### EXISTS/NOT EXISTS Subquery
+In Spark, `EXISTS` and `NOT EXISTS` expressions are allowed inside a `WHERE` clause.
+These are boolean expressions which return either `TRUE` or
+`FALSE`. In other words, `EXISTS` is a membership condition and returns `TRUE`
+when the subquery it refers to returns one or more rows. Similarly, `NOT EXISTS`
+is a non-membership condition and returns `TRUE` when the subquery returns no rows.
+
+These two expressions are not affected by the presence of `NULL` values in the result of
+the subquery.
+
+#### Examples
+{% highlight sql %}
+-- Even if subquery produces rows with `NULL` values, the `EXISTS` expression
+-- evaluates to `TRUE` as the subquery produces 1 row.
+SELECT * FROM person WHERE EXISTS (SELECT null);
+ +--------+----+
+ |name |age |
+ +--------+----+
+ |Albert |null|
+ |Michelle|30 |
+ |Fred |50 |
+ |Mike |18 |
+ |Dan |50 |
+ |Marry |null|
+ |Joe |30 |
+ +--------+----+
+
+-- The `NOT EXISTS` expression returns `FALSE`. It returns `TRUE` only when
+-- the subquery produces no rows. In this case, the subquery produces 1 row.
+SELECT * FROM person WHERE NOT EXISTS (SELECT null);
+ +----+---+
+ |name|age|
+ +----+---+
+ +----+---+
+
+-- `NOT EXISTS` expression returns `TRUE`.
+SELECT * FROM person WHERE NOT EXISTS (SELECT 1 WHERE 1 = 0);
+ +--------+----+
+ |name |age |
+ +--------+----+
+ |Albert |null|
+ |Michelle|30 |
+ |Fred |50 |
+ |Mike |18 |
+ |Dan |50 |
+ |Marry |null|
+ |Joe |30 |
+ +--------+----+
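+
+-- Illustrative: a correlated `EXISTS` subquery (assumes the `unknown_age` view defined in the
+-- set operators section above). Only persons whose name appears in `unknown_age` are returned;
+-- the `NULL` values produced by the subquery do not affect the result.
+SELECT * FROM person p WHERE EXISTS (SELECT 1 FROM unknown_age u WHERE u.name = p.name);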
+{% endhighlight %}
+
+### IN/NOT IN Subquery
+In Spark, `IN` and `NOT IN` expressions are allowed inside a `WHERE` clause of
+a query. Unlike the `EXISTS` expression, an `IN` expression can return a `TRUE`,
+`FALSE` or `UNKNOWN (NULL)` value. Conceptually, an `IN` expression is semantically
+equivalent to a set of equality conditions separated by a disjunctive operator (`OR`).
+For example, `c1 IN (1, 2, 3)` is semantically equivalent to `(c1 = 1 OR c1 = 2 OR c1 = 3)`.
+
+As far as handling `NULL` values is concerned, the semantics can be deduced from
+the `NULL` value handling in comparison operators (`=`) and logical operators (`OR`).
+To summarize, below are the rules for computing the result of an `IN` expression.
+
+- `TRUE` is returned when the non-`NULL` value in question is found in the list.
+- `FALSE` is returned when the non-`NULL` value is not found in the list and the
+  list does not contain `NULL` values.
+- `UNKNOWN` is returned when the value is `NULL`, or the non-`NULL` value is not found in the list
+  and the list contains at least one `NULL` value.
+
+#### Examples
+{% highlight sql %}
+-- The subquery has only a `NULL` value in its result set. Therefore,
+-- the result of the `IN` predicate is UNKNOWN and no rows are returned.
+SELECT * FROM person WHERE age IN (SELECT null);
+ +----+---+
+ |name|age|
+ +----+---+
+ +----+---+
+
+-- The subquery has `NULL` value in the result set as well as a valid
+-- value `50`. Rows with age = 50 are returned.
+SELECT * FROM person
+WHERE age IN (SELECT age FROM VALUES (50), (null) sub(age));
+ +----+---+
+ |name|age|
+ +----+---+
+ |Fred|50 |
+ |Dan |50 |
+ +----+---+
+
+-- Since subquery has `NULL` value in the result set, the `NOT IN`
+-- predicate would return UNKNOWN. Hence, no rows are
+-- qualified for this query.
+SELECT * FROM person
+WHERE age NOT IN (SELECT age FROM VALUES (50), (null) sub(age));
+ +----+---+
+ |name|age|
+ +----+---+
+ +----+---+
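+
+-- Illustrative: a literal `IN` list containing `NULL` behaves the same way. `age NOT IN (18, null)`
+-- is equivalent to `NOT (age = 18 OR age = NULL)`, which is never `TRUE`, so no rows are returned.
+SELECT * FROM person WHERE age NOT IN (18, null);
+ +----+---+
+ |name|age|
+ +----+---+
+ +----+---+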
+
+{% endhighlight %}
diff --git a/docs/sql-ref-syntax-aux-cache-clear-cache.md b/docs/sql-ref-syntax-aux-cache-clear-cache.md
index 88d126f0f528e..d8e451a230a71 100644
--- a/docs/sql-ref-syntax-aux-cache-clear-cache.md
+++ b/docs/sql-ref-syntax-aux-cache-clear-cache.md
@@ -19,4 +19,20 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+`CLEAR CACHE` removes the entries and associated data from the in-memory and/or on-disk cache for all cached tables and views.
+
+### Syntax
+{% highlight sql %}
+CLEAR CACHE
+{% endhighlight %}
+
+### Examples
+{% highlight sql %}
+CLEAR CACHE;
+{% endhighlight %}
+
+### Related Statements
+ * [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html)
+ * [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html)
+
diff --git a/docs/sql-ref-syntax-aux-cache-uncache-table.md b/docs/sql-ref-syntax-aux-cache-uncache-table.md
index 69819fee088da..a6cb4d6807b22 100644
--- a/docs/sql-ref-syntax-aux-cache-uncache-table.md
+++ b/docs/sql-ref-syntax-aux-cache-uncache-table.md
@@ -20,7 +20,7 @@ license: |
---
### Description
-`UNCACHE TABLE` removes the entries and associated data from the in-memory and/or on-disk cache for a given table. The
+`UNCACHE TABLE` removes the entries and associated data from the in-memory and/or on-disk cache for a given table or view. The
underlying entries should already have been brought to cache by previous `CACHE TABLE` operation. `UNCACHE TABLE` on a non-existent table throws Exception if `IF EXISTS` is not specified.
### Syntax
{% highlight sql %}
@@ -29,7 +29,7 @@ UNCACHE TABLE [ IF EXISTS ] table_name
### Parameters
table_name
- The name of the table to be uncached.
+ The name of the table or view to be uncached.
### Examples
{% highlight sql %}
diff --git a/docs/sql-ref-syntax-aux-cache.md b/docs/sql-ref-syntax-aux-cache.md
index eb0e73d00e848..c3dcb276a7e0f 100644
--- a/docs/sql-ref-syntax-aux-cache.md
+++ b/docs/sql-ref-syntax-aux-cache.md
@@ -1,7 +1,7 @@
---
layout: global
-title: Reference
-displayTitle: Reference
+title: Cache
+displayTitle: Cache
license: |
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
@@ -9,9 +9,9 @@ license: |
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
-
+
http://www.apache.org/licenses/LICENSE-2.0
-
+
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -19,7 +19,6 @@ license: |
limitations under the License.
---
-Spark SQL is a Apache Spark's module for working with structured data.
-This guide is a reference for Structured Query Language (SQL) for Apache
-Spark. This document describes the SQL constructs supported by Spark in detail
-along with usage examples when applicable.
+* [CACHE TABLE statement](sql-ref-syntax-aux-cache-cache-table.html)
+* [UNCACHE TABLE statement](sql-ref-syntax-aux-cache-uncache-table.html)
+* [CLEAR CACHE statement](sql-ref-syntax-aux-cache-clear-cache.html)
diff --git a/docs/sql-ref-syntax-aux-conf-mgmt-reset.md b/docs/sql-ref-syntax-aux-conf-mgmt-reset.md
index ad2d7f9a83316..8ee61514ee4ef 100644
--- a/docs/sql-ref-syntax-aux-conf-mgmt-reset.md
+++ b/docs/sql-ref-syntax-aux-conf-mgmt-reset.md
@@ -19,4 +19,20 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+Reset all the properties specific to the current session to their default values. After the `RESET` command, executing the `SET` command will output an empty result.
+
+### Syntax
+{% highlight sql %}
+RESET
+{% endhighlight %}
+
+
+### Examples
+{% highlight sql %}
+-- Reset all the properties specific to the current session to their default values.
+RESET;
+{% endhighlight %}
+
+### Related Statements
+- [SET](sql-ref-syntax-aux-conf-mgmt-set.html)
diff --git a/docs/sql-ref-syntax-aux-conf-mgmt-set.md b/docs/sql-ref-syntax-aux-conf-mgmt-set.md
index c38d68dbb4f1d..f05dde3f567ee 100644
--- a/docs/sql-ref-syntax-aux-conf-mgmt-set.md
+++ b/docs/sql-ref-syntax-aux-conf-mgmt-set.md
@@ -19,4 +19,51 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+The `SET` command sets a property, returns the value of an existing property, or returns all SQLConf properties with their value and meaning.
+
+### Syntax
+{% highlight sql %}
+SET
+SET [ -v ]
+SET property_key[ = property_value ]
+{% endhighlight %}
+
+### Parameters
+
+* `-v`: Outputs the key, value and meaning of existing SQLConf properties.
+
+* `property_key`: Returns the value of specified property key.
+
+* `property_key=property_value`: Sets the value for a given property key. If an old value exists for a given property key, then it gets overridden by the new value.
+
+
+### Examples
+{% highlight sql %}
+-- Set a property.
+SET spark.sql.variable.substitute=false;
+
+-- List all SQLConf properties with value and meaning.
+SET -v;
+
+-- List all SQLConf properties with value for current session.
+SET;
+
+-- List the value of specified property key.
+SET spark.sql.variable.substitute;
+ +--------------------------------+--------+
+ | key | value |
+ +--------------------------------+--------+
+ | spark.sql.variable.substitute | false |
+ +--------------------------------+--------+
+{% endhighlight %}
+
+### Related Statements
+- [RESET](sql-ref-syntax-aux-conf-mgmt-reset.html)
diff --git a/docs/sql-ref-syntax-aux-refresh-table.md b/docs/sql-ref-syntax-aux-refresh-table.md
new file mode 100644
index 0000000000000..262382a467073
--- /dev/null
+++ b/docs/sql-ref-syntax-aux-refresh-table.md
@@ -0,0 +1,58 @@
+---
+layout: global
+title: REFRESH TABLE
+displayTitle: REFRESH TABLE
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+### Description
+The `REFRESH TABLE` statement invalidates the cached entries, which include data
+and metadata of the given table or view. The invalidated cache is populated in a
+lazy manner when the cached table or the query associated with it is executed again.
+
+### Syntax
+{% highlight sql %}
+REFRESH [TABLE] tableIdentifier
+{% endhighlight %}
+
+### Parameters
+
+* `tableIdentifier`: Specifies a table name, which is either a qualified or unqualified name that designates a table/view. If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
+
+  Syntax: `[database_name.]table_name`
+
+
+
+
+### Examples
+{% highlight sql %}
+-- The cached entries of the table will be refreshed
+-- The table is resolved from the current database as the table name is unqualified.
+REFRESH TABLE tbl1;
+
+-- The cached entries of the view will be refreshed or invalidated
+-- The view is resolved from tempDB database, as the view name is qualified.
+REFRESH TABLE tempDB.view1;
+{% endhighlight %}
+
+### Related Statements
+- [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html)
+- [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html)
+- [UNCACHE TABLE](sql-ref-syntax-aux-cache-uncache-table.html)
\ No newline at end of file
diff --git a/docs/sql-ref-syntax-aux-show-create-table.md b/docs/sql-ref-syntax-aux-show-create-table.md
index 2cf40915774c4..7871d30b5b186 100644
--- a/docs/sql-ref-syntax-aux-show-create-table.md
+++ b/docs/sql-ref-syntax-aux-show-create-table.md
@@ -19,4 +19,46 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+`SHOW CREATE TABLE` returns the [CREATE TABLE statement](sql-ref-syntax-ddl-create-table.html) or [CREATE VIEW statement](sql-ref-syntax-ddl-create-view.html) that was used to create a given table or view. `SHOW CREATE TABLE` on a non-existent table or a temporary view throws an exception.
+
+### Syntax
+{% highlight sql %}
+SHOW CREATE TABLE name
+{% endhighlight %}
+
+### Parameters
+
+* `name`: The name of the table or view to be used for SHOW CREATE TABLE.
+
+
+### Examples
+{% highlight sql %}
+CREATE TABLE test (c INT) ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
+ STORED AS TEXTFILE
+ TBLPROPERTIES ('prop1' = 'value1', 'prop2' = 'value2');
+
+show create table test;
+
+-- the result of SHOW CREATE TABLE test
+CREATE TABLE `test`(`c` INT)
+ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
+WITH SERDEPROPERTIES (
+ 'field.delim' = ',',
+ 'serialization.format' = ','
+)
+STORED AS
+ INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'
+ OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
+TBLPROPERTIES (
+ 'transient_lastDdlTime' = '1569350233',
+ 'prop1' = 'value1',
+ 'prop2' = 'value2'
+)
+
+{% endhighlight %}
+
+### Related Statements
+ * [CREATE TABLE](sql-ref-syntax-ddl-create-table.html)
+ * [CREATE VIEW](sql-ref-syntax-ddl-create-view.html)
diff --git a/docs/sql-ref-syntax-ddl-create-database.md b/docs/sql-ref-syntax-ddl-create-database.md
index bbcd34a6d6853..ed0bbf629b027 100644
--- a/docs/sql-ref-syntax-ddl-create-database.md
+++ b/docs/sql-ref-syntax-ddl-create-database.md
@@ -19,4 +19,61 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+Creates a database with the specified name. If a database with the same name already exists, an exception will be thrown.
+
+### Syntax
+{% highlight sql %}
+CREATE {DATABASE | SCHEMA} [ IF NOT EXISTS ] database_name
+ [ COMMENT database_comment ]
+ [ LOCATION database_directory ]
+ [ WITH DBPROPERTIES (property_name=property_value [ , ...]) ]
+{% endhighlight %}
+
+### Parameters
+
+* `database_name`: Specifies the name of the database to be created.
+
+* `IF NOT EXISTS`: Creates a database with the given name if it doesn't exist. If a database with the same name already exists, nothing will happen.
+
+* `database_directory`: Path of the file system in which the specified database is to be created. If the specified path does not exist in the underlying file system, this command creates a directory with the path. If the location is not specified, the database will be created in the default warehouse directory, whose path is configured by the static configuration `spark.sql.warehouse.dir`.
+
+* `database_comment`: Specifies the description for the database.
+
+* `WITH DBPROPERTIES (property_name=property_value [ , ...])`: Specifies the properties for the database in key-value pairs.
+
+
+### Examples
+{% highlight sql %}
+-- Create database `customer_db`. This throws an exception if a database with the name customer_db
+-- already exists.
+CREATE DATABASE customer_db;
+
+-- Create database `customer_db` only if a database with the same name doesn't already exist.
+CREATE DATABASE IF NOT EXISTS customer_db;
+
+-- Create database `customer_db` only if a database with the same name doesn't already exist, with
+-- a comment, a specific location and database properties.
+CREATE DATABASE IF NOT EXISTS customer_db COMMENT 'This is customer database' LOCATION '/user'
+ WITH DBPROPERTIES (ID=001, Name='John');
+
+-- Verify that properties are set.
+DESCRIBE DATABASE EXTENDED customer_db;
+ +----------------------------+-----------------------------+
+ | database_description_item | database_description_value |
+ +----------------------------+-----------------------------+
+ | Database Name | customer_db |
+ | Description | This is customer database |
+ | Location | hdfs://hacluster/user |
+ | Properties | ((ID,001), (Name,John)) |
+ +----------------------------+-----------------------------+
+{% endhighlight %}
+
+### Related Statements
+- [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html)
+- [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html)
diff --git a/docs/sql-ref-syntax-ddl-create-view.md b/docs/sql-ref-syntax-ddl-create-view.md
index eff7df91f59c5..c7ca28ea5b62f 100644
--- a/docs/sql-ref-syntax-ddl-create-view.md
+++ b/docs/sql-ref-syntax-ddl-create-view.md
@@ -19,4 +19,64 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+Views are based on the result set of a `SQL` query. `CREATE VIEW` constructs
+a virtual table that has no physical data, therefore other operations like
+`ALTER VIEW` and `DROP VIEW` only change metadata.
+
+### Syntax
+{% highlight sql %}
+CREATE [OR REPLACE] [[GLOBAL] TEMPORARY] VIEW [IF NOT EXISTS] [db_name.]view_name
+ create_view_clauses
+ AS query;
+{% endhighlight %}
+
+### Parameters
+
+* `OR REPLACE`: If a view of the same name already exists, it will be replaced.
+
+* `[GLOBAL] TEMPORARY`: `TEMPORARY` views are session-scoped and will be dropped when the session ends
+  because they skip persisting the definition in the underlying metastore, if any.
+  `GLOBAL TEMPORARY` views are tied to a system preserved temporary database `global_temp`.
+
+* `IF NOT EXISTS`: Creates a view if it does not exist.
+
+* `create_view_clauses`: These clauses are optional and order insensitive. They can be in any of the following formats:
+  * `[(column_name [COMMENT column_comment], ...)]` to specify column-level comments.
+  * `[COMMENT view_comment]` to specify view-level comments.
+  * `[TBLPROPERTIES (property_name = property_value, ...)]` to add metadata key-value pairs.
+
+* `query`: A SELECT statement that constructs the view from base tables or other views.
+
+
+### Examples
+{% highlight sql %}
+-- Create or replace view for `experienced_employee` with comments.
+CREATE OR REPLACE VIEW experienced_employee
+ (ID COMMENT 'Unique identification number', Name)
+ COMMENT 'View for experienced employees'
+ AS SELECT id, name FROM all_employee
+ WHERE working_years > 5;
+
+-- Create a global temporary view `subscribed_movies` if it does not exist.
+CREATE GLOBAL TEMPORARY VIEW IF NOT EXISTS subscribed_movies
+ AS SELECT mo.member_id, mb.full_name, mo.movie_title
+ FROM movies AS mo INNER JOIN members AS mb
+ ON mo.member_id = mb.id;
+{% endhighlight %}
+
+### Related Statements
+- [ALTER VIEW](sql-ref-syntax-ddl-alter-view.html)
+- [DROP VIEW](sql-ref-syntax-ddl-drop-view.html)
diff --git a/docs/sql-ref-syntax-ddl-drop-database.md b/docs/sql-ref-syntax-ddl-drop-database.md
index cd900a7e393db..f3cdbf91a8d2a 100644
--- a/docs/sql-ref-syntax-ddl-drop-database.md
+++ b/docs/sql-ref-syntax-ddl-drop-database.md
@@ -19,4 +19,62 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+
+Drop a database and delete the directory associated with the database from the file system. An
+exception will be thrown if the database does not exist in the system.
+
+### Syntax
+
+{% highlight sql %}
+DROP (DATABASE|SCHEMA) [IF EXISTS] dbname [RESTRICT|CASCADE];
+{% endhighlight %}
+
+
+### Parameters
+
+
+ DATABASE|SCHEMA
+ `DATABASE` and `SCHEMA` mean the same thing, either of them can be used.
+
+
+
+ IF EXISTS
+ If specified, no exception is thrown when the database does not exist.
+
+
+
+ RESTRICT
+ If specified, will restrict dropping a non-empty database and is enabled by default.
+
+
+
+ CASCADE
+ If specified, will drop all the associated tables and functions.
+
+
+### Example
+{% highlight sql %}
+-- Create `inventory_db` Database
+CREATE DATABASE inventory_db COMMENT 'This database is used to maintain Inventory';
+
+-- Drop the database and its tables
+DROP DATABASE inventory_db CASCADE;
++---------+
+| Result |
++---------+
++---------+
+
+-- Drop the database using IF EXISTS
+DROP DATABASE IF EXISTS inventory_db CASCADE;
++---------+
+| Result |
++---------+
++---------+
+
+{% endhighlight %}
+
+### Related statements
+- [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html)
+- [DESCRIBE DATABASE](sql-ref-syntax-aux-describe-database.html)
+- [SHOW DATABASES](sql-ref-syntax-aux-show-databases.html)
\ No newline at end of file
diff --git a/docs/sql-ref-syntax-ddl-truncate-table.md b/docs/sql-ref-syntax-ddl-truncate-table.md
index 2704259391e94..4b4094ab708e5 100644
--- a/docs/sql-ref-syntax-ddl-truncate-table.md
+++ b/docs/sql-ref-syntax-ddl-truncate-table.md
@@ -19,4 +19,68 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+The `TRUNCATE TABLE` statement removes all the rows from a table or partition(s). The table must not be a view
+or an external/temporary table. In order to truncate multiple partitions at once, the user can specify the partitions
+in `partition_spec`. If no `partition_spec` is specified, it removes all rows from all partitions in the table.
+
+### Syntax
+{% highlight sql %}
+TRUNCATE TABLE table_name [PARTITION partition_spec];
+{% endhighlight %}
+
+### Parameters
+
+* `table_name`: The name of an existing table.
+
+* `PARTITION ( partition_spec: [ partition_column = partition_col_value, partition_column = partition_col_value, ... ] )`: Specifies one or more partition column and value pairs. The partition value is optional.
+
+
+
+### Examples
+{% highlight sql %}
+
+-- Create table Student with partition
+CREATE TABLE Student ( name String, rollno INT) PARTITIONED BY (age int);
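+
+-- Populate the table (illustrative rows matching the query output below).
+INSERT INTO Student PARTITION (age=10) VALUES ('ABC', 1), ('DEF', 2);
+INSERT INTO Student PARTITION (age=12) VALUES ('XYZ', 3);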
+
+SELECT * from Student;
++-------+---------+------+--+
+| name | rollno | age |
++-------+---------+------+--+
+| ABC | 1 | 10 |
+| DEF | 2 | 10 |
+| XYZ | 3 | 12 |
++-------+---------+------+--+
+
+-- Removes all rows from the table in the partition specified
+TRUNCATE TABLE Student partition(age=10);
+
+-- After truncate execution, records belonging to partition age=10 are removed
+SELECT * from Student;
++-------+---------+------+--+
+| name | rollno | age |
++-------+---------+------+--+
+| XYZ | 3 | 12 |
++-------+---------+------+--+
+
+-- Removes all rows from the table from all partitions
+TRUNCATE TABLE Student;
+
+SELECT * from Student;
++-------+---------+------+--+
+| name | rollno | age |
++-------+---------+------+--+
++-------+---------+------+--+
+No rows selected
+
+{% endhighlight %}
+
+
+### Related Statements
+- [DROP TABLE](sql-ref-syntax-ddl-drop-table.html)
+- [ALTER TABLE](sql-ref-syntax-ddl-alter-table.html)
+
diff --git a/docs/sql-ref-syntax-qry-select-usedb.md b/docs/sql-ref-syntax-qry-select-usedb.md
new file mode 100644
index 0000000000000..92ac91ac51769
--- /dev/null
+++ b/docs/sql-ref-syntax-qry-select-usedb.md
@@ -0,0 +1,60 @@
+---
+layout: global
+title: USE Database
+displayTitle: USE Database
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+### Description
+The `USE` statement is used to set the current database. After the current database is set,
+unqualified database artifacts such as tables, functions and views that are
+referenced by SQL statements are resolved from the current database.
+The default database name is 'default'.
+
+### Syntax
+{% highlight sql %}
+USE database_name
+{% endhighlight %}
+
+### Parameter
+
+
+* `database_name`: Name of the database to be used. If the database does not exist, an exception will be thrown.
+
+
+
+### Example
+{% highlight sql %}
+-- Use the database 'userdb', which exists.
+USE userdb;
++---------+--+
+| Result |
++---------+--+
++---------+--+
+
+-- Use the database 'userdb1', which doesn't exist.
+USE userdb1;
+Error: org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: Database 'userdb1' not found;(state=,code=0)
+{% endhighlight %}
+
+### Related Statements
+- [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html)
+- [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html)
+- [CREATE TABLE](sql-ref-syntax-ddl-create-table.html)
+
diff --git a/docs/ss-migration-guide.md b/docs/ss-migration-guide.md
new file mode 100644
index 0000000000000..b0fd8a8325dff
--- /dev/null
+++ b/docs/ss-migration-guide.md
@@ -0,0 +1,32 @@
+---
+layout: global
+title: "Migration Guide: Structured Streaming"
+displayTitle: "Migration Guide: Structured Streaming"
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+* Table of contents
+{:toc}
+
+Note that this migration guide describes the items specific to Structured Streaming.
+Many items of SQL migration can be applied when migrating Structured Streaming to higher versions.
+Please refer to [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide.html).
+
+## Upgrading from Structured Streaming 2.4 to 3.0
+
+- In Spark 3.0, Structured Streaming forces the source schema into nullable when file-based datasources such as text, json, csv, parquet and orc are used via `spark.readStream(...)`. Previously, it respected the nullability in source schema; however, it caused issues tricky to debug with NPE. To restore the previous behavior, set `spark.sql.streaming.fileSource.schema.forceNullable` to `false`.
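+
+A minimal sketch of restoring the previous behavior from a SQL session (the same property can also be
+set through the `SparkSession` configuration):
+
+{% highlight sql %}
+-- Restore the pre-3.0 behavior of respecting the nullability declared in the source schema.
+SET spark.sql.streaming.fileSource.schema.forceNullable=false;
+{% endhighlight %}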
+
diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md
index 55acec53302e4..3389d453c2cbd 100644
--- a/docs/streaming-kinesis-integration.md
+++ b/docs/streaming-kinesis-integration.md
@@ -64,13 +64,13 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
- import org.apache.spark.storage.StorageLevel
- import org.apache.spark.streaming.kinesis.KinesisInputDStream
- import org.apache.spark.streaming.Seconds
- import org.apache.spark.streaming.StreamingContext
- import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+ import org.apache.spark.storage.StorageLevel;
+ import org.apache.spark.streaming.kinesis.KinesisInputDStream;
+ import org.apache.spark.streaming.Seconds;
+ import org.apache.spark.streaming.StreamingContext;
+ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
- KinesisInputDStream kinesisStream = KinesisInputDStream.builder
+ KinesisInputDStream kinesisStream = KinesisInputDStream.builder()
.streamingContext(streamingContext)
.endpointUrl([endpoint URL])
.regionName([region name])
@@ -81,7 +81,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
.storageLevel(StorageLevel.MEMORY_AND_DISK_2)
.build();
- See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html)
+ See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisInputDStream.html)
and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the [Running the Example](#running-the-example) subsection for instructions to run the example.
@@ -98,14 +98,21 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
- You may also provide a "message handler function" that takes a Kinesis `Record` and returns a generic object `T`, in case you would like to use other data included in a `Record` such as partition key. This is currently only supported in Scala and Java.
+ You may also provide the following settings. These are currently only supported in Scala and Java.
+
+ - A "message handler function" that takes a Kinesis `Record` and returns a generic object `T`, in case you would like to use other data included in a `Record` such as partition key.
+
+ - CloudWatch metrics level and dimensions. See [the AWS documentation about monitoring KCL](https://docs.aws.amazon.com/streams/latest/dev/monitoring-with-kcl.html) for details.
+ import collection.JavaConverters._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.kinesis.KinesisInputDStream
import org.apache.spark.streaming.{Seconds, StreamingContext}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration
+ import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel
val kinesisStream = KinesisInputDStream.builder
.streamingContext(streamingContext)
@@ -116,17 +123,22 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
.checkpointAppName([Kinesis app name])
.checkpointInterval([checkpoint interval])
.storageLevel(StorageLevel.MEMORY_AND_DISK_2)
+ .metricsLevel(MetricsLevel.DETAILED)
+ .metricsEnabledDimensions(KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS.asScala.toSet)
.buildWithMessageHandler([message handler])
- import org.apache.spark.storage.StorageLevel
- import org.apache.spark.streaming.kinesis.KinesisInputDStream
- import org.apache.spark.streaming.Seconds
- import org.apache.spark.streaming.StreamingContext
- import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
-
- KinesisInputDStream kinesisStream = KinesisInputDStream.builder
+ import org.apache.spark.storage.StorageLevel;
+ import org.apache.spark.streaming.kinesis.KinesisInputDStream;
+ import org.apache.spark.streaming.Seconds;
+ import org.apache.spark.streaming.StreamingContext;
+ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
+ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration;
+ import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel;
+ import scala.collection.JavaConverters;
+
+ KinesisInputDStream kinesisStream = KinesisInputDStream.builder()
.streamingContext(streamingContext)
.endpointUrl([endpoint URL])
.regionName([region name])
@@ -135,6 +147,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
.checkpointAppName([Kinesis app name])
.checkpointInterval([checkpoint interval])
.storageLevel(StorageLevel.MEMORY_AND_DISK_2)
+ .metricsLevel(MetricsLevel.DETAILED)
+ .metricsEnabledDimensions(JavaConverters.asScalaSetConverter(KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS).asScala().toSet())
.buildWithMessageHandler([message handler]);
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index f5abed74bff20..f6b579fbf74d1 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -2488,13 +2488,13 @@ additional effort may be necessary to achieve exactly-once semantics. There are
* [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and
[DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream)
* [KafkaUtils](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$),
- [KinesisUtils](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$),
+ [KinesisUtils](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisInputDStream),
- Java docs
* [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html),
[JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and
[JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html)
* [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html),
- [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html)
+ [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisInputDStream.html)
- Python docs
* [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) and [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream)
* [KafkaUtils](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils)
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md
index c4378b4a02663..89732d309aa27 100644
--- a/docs/structured-streaming-kafka-integration.md
+++ b/docs/structured-streaming-kafka-integration.md
@@ -27,6 +27,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli
artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}
version = {{site.SPARK_VERSION_SHORT}}
+Please note that to use the headers functionality, your Kafka client version should be 0.11.0.0 or higher.
+
For Python applications, you need to add this above library and its dependencies when deploying your
application. See the [Deploying](#deploying) subsection below.
@@ -50,6 +52,17 @@ val df = spark
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]
+// Subscribe to 1 topic, with headers
+val df = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+ .option("subscribe", "topic1")
+ .option("includeHeaders", "true")
+ .load()
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
+ .as[(String, String, Map)]
+
// Subscribe to multiple topics
val df = spark
.readStream
@@ -84,6 +97,16 @@ Dataset
df = spark
.load();
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");
+// Subscribe to 1 topic, with headers
+Dataset df = spark
+ .readStream()
+ .format("kafka")
+ .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+ .option("subscribe", "topic1")
+ .option("includeHeaders", "true")
+ .load()
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers");
+
// Subscribe to multiple topics
Dataset df = spark
.readStream()
@@ -116,6 +139,16 @@ df = spark \
.load()
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to 1 topic, with headers
+df = spark \
+ .readStream \
+ .format("kafka") \
+ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+ .option("subscribe", "topic1") \
+ .option("includeHeaders", "true") \
+ .load()
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
+
# Subscribe to multiple topics
df = spark \
.readStream \
@@ -286,6 +319,10 @@ Each row in the source has the following schema:
timestampType
int
+
+ headers (optional)
+ array
+
The following options must be set for the Kafka source
@@ -325,6 +362,27 @@ The following configurations are optional:
Option value default query type meaning
+
+ startingOffsetsByTimestamp
+ json string
+ """ {"topicA":{"0": 1000, "1": 1000}, "topicB": {"0": 2000, "1": 2000}} """
+
+ none (the value of startingOffsets will apply)
+ streaming and batch
+ The start point of timestamp when a query is started, a json string specifying a starting timestamp for
+ each TopicPartition. The returned offset for each partition is the earliest offset whose timestamp is greater than or
+ equal to the given timestamp in the corresponding partition. If the matched offset doesn't exist,
+ the query will fail immediately to prevent unintended read from such partition. (This is a kind of limitation as of now, and will be addressed in near future.)
+
+ Spark simply passes the timestamp information to KafkaConsumer.offsetsForTimes, and doesn't interpret or reason about the value.
+ For more details on KafkaConsumer.offsetsForTimes, please refer to the javadoc for details.
+ Also the meaning of timestamp here can vary according to the Kafka configuration (log.message.timestamp.type): please refer to the Kafka documentation for further details.
+ Note: This option requires Kafka 0.10.1.0 or higher.
+ Note2: startingOffsetsByTimestamp takes precedence over startingOffsets.
+ Note3: For streaming queries, this only applies when a new query is started, and that resuming will
+ always pick up from where the query left off. Newly discovered partitions during a query will start at
+ earliest.
+
startingOffsets
"earliest", "latest" (streaming only), or json string
@@ -340,6 +398,25 @@ The following configurations are optional:
always pick up from where the query left off. Newly discovered partitions during a query will start at
earliest.
+
+ endingOffsetsByTimestamp
+ json string
+ """ {"topicA":{"0": 1000, "1": 1000}, "topicB": {"0": 2000, "1": 2000}} """
+
+ latest
+ batch query
+ The end point when a batch query is ended, a json string specifying an ending timestamp for each TopicPartition.
+ The returned offset for each partition is the earliest offset whose timestamp is greater than or equal to
+ the given timestamp in the corresponding partition. If the matched offset doesn't exist, the offset will
+ be set to latest.
+
+ Spark simply passes the timestamp information to KafkaConsumer.offsetsForTimes, and doesn't interpret or reason about the value.
+ For more details on KafkaConsumer.offsetsForTimes, please refer to the javadoc for details.
+ Also the meaning of timestamp here can vary according to the Kafka configuration (log.message.timestamp.type): please refer to the Kafka documentation for further details.
+ Note: This option requires Kafka 0.10.1.0 or higher.
+ Note2: endingOffsetsByTimestamp takes precedence over endingOffsets.
+
+
endingOffsets
latest or json string
@@ -425,6 +502,13 @@ The following configurations are optional:
issues, set the Kafka consumer session timeout (by setting option "kafka.session.timeout.ms") to
be very small. When this is set, option "groupIdPrefix" will be ignored.
+
+ includeHeaders
+ boolean
+ false
+ streaming and batch
+ Whether to include the Kafka headers in the row.
+
### Consumer Caching
@@ -522,6 +606,10 @@ The Dataframe being written to Kafka should have the following columns in schema
value (required)
string or binary
+
+ headers (optional)
+ array
+
topic (*optional)
string
@@ -559,6 +647,13 @@ The following configurations are optional:
Sets the topic that all rows will be written to in Kafka. This option overrides any
topic column that may exist in the data.
+
+ includeHeaders
+ boolean
+ false
+ streaming and batch
+ Whether to include the Kafka headers in the row.
+
### Creating a Kafka Sink for Streaming Queries
@@ -825,7 +920,9 @@ Delegation tokens can be obtained from multiple clusters and ${cluster}spark.kafka.clusters.${cluster}.security.protocol
SASL_SSL
- Protocol used to communicate with brokers. For further details please see Kafka documentation. Only used to obtain delegation token.
+ Protocol used to communicate with brokers. For further details please see Kafka documentation. Protocol is applied on all the sources and sinks as default where
+ bootstrap.servers config matches (for further details please see spark.kafka.clusters.${cluster}.target.bootstrap.servers.regex),
+ and can be overridden by setting kafka.security.protocol on the source or sink.
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index deaf262c5f572..2a405f36fd5fd 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -1505,7 +1505,6 @@ Additional details on supported joins:
- Cannot use mapGroupsWithState and flatMapGroupsWithState in Update mode before joins.
-
### Streaming Deduplication
You can deduplicate records in data streams using a unique identifier in the events. This is exactly same as deduplication on static using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking.
@@ -1616,6 +1615,8 @@ this configuration judiciously.
### Arbitrary Stateful Operations
Many usecases require more advanced stateful operations than aggregations. For example, in many usecases, you have to track sessions from data streams of events. For doing such sessionization, you will have to save arbitrary types of data as state, and perform arbitrary operations on the state using the data stream events in every trigger. Since Spark 2.2, this can be done using the operation `mapGroupsWithState` and the more powerful operation `flatMapGroupsWithState`. Both operations allow you to apply user-defined code on grouped Datasets to update user-defined state. For more concrete details, take a look at the API documentation ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.GroupState)/[Java](api/java/org/apache/spark/sql/streaming/GroupState.html)) and the examples ([Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java)).
+Though Spark cannot check and force it, the state function should be implemented with respect to the semantics of the output mode. For example, in Update mode Spark doesn't expect that the state function will emit rows which are older than current watermark plus allowed late record delay, whereas in Append mode the state function can emit these rows.
+
### Unsupported Operations
There are a few DataFrame/Dataset operations that are not supported with streaming DataFrames/Datasets.
Some of them are as follows.
@@ -1647,6 +1648,26 @@ For example, sorting on the input stream is not supported, as it requires keepin
track of all the data received in the stream. This is therefore fundamentally hard to execute
efficiently.
+### Limitation of global watermark
+
+In Append mode, if a stateful operation emits rows older than current watermark plus allowed late record delay,
+they will be "late rows" in downstream stateful operations (as Spark uses global watermark). Note that these rows may be discarded.
+This is a limitation of a global watermark, and it could potentially cause a correctness issue.
+
+Spark will check the logical plan of the query and log a warning when Spark detects such a pattern.
+
+Any stateful operation(s) after any of the below stateful operations can have this issue:
+
+* streaming aggregation in Append mode
+* stream-stream outer join
+* `mapGroupsWithState` and `flatMapGroupsWithState` in Append mode (depending on the implementation of the state function)
+
+As Spark cannot check the state function of `mapGroupsWithState`/`flatMapGroupsWithState`, Spark assumes that the state function
+emits late rows if the operator uses Append mode.
+
+There's a known workaround: split your streaming query into multiple queries per stateful operator, and ensure
+end-to-end exactly once per query. Ensuring end-to-end exactly once for the last query is optional.
+
## Starting Streaming Queries
Once you have defined the final result DataFrame/Dataset, all that is left is for you to start the streaming computation. To do that, you have to use the `DataStreamWriter`
([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs)
diff --git a/docs/web-ui.md b/docs/web-ui.md
index 72423d9468e83..e6025370e6796 100644
--- a/docs/web-ui.md
+++ b/docs/web-ui.md
@@ -404,3 +404,44 @@ The web UI includes a Streaming tab if the application uses Spark streaming. Thi
scheduling delay and processing time for each micro-batch in the data stream, which can be useful
for troubleshooting the streaming application.
+## JDBC/ODBC Server Tab
+We can see this tab when Spark is running as a [distributed SQL engine](sql-distributed-sql-engine.html). It shows information about sessions and submitted SQL operations.
+
+The first section of the page displays general information about the JDBC/ODBC server: start time and uptime.
+
+
+
+
+
+The second section contains information about active and finished sessions.
+* **User** and **IP** of the connection.
+* **Session id** link to access the session info.
+* **Start time**, **finish time** and **duration** of the session.
+* **Total execute** is the number of operations submitted in this session.
+
+
+
+
+
+The third section has the SQL statistics of the submitted operations.
+* **User** that submitted the operation.
+* **Job id** link to [jobs tab](web-ui.html#jobs-tab).
+* **Group id** of the query that groups all jobs together. An application can cancel all running jobs using this group id.
+* **Start time** of the operation.
+* **Finish time** of the execution, before fetching the results.
+* **Close time** of the operation after fetching the results.
+* **Execution time** is the difference between finish time and start time.
+* **Duration time** is the difference between close time and start time.
+* **Statement** is the operation being executed.
+* **State** of the process.
+ * _Started_, first state, when the process begins.
+ * _Compiled_, execution plan generated.
+ * _Failed_, final state when the execution failed or finished with error.
+ * _Canceled_, final state when the execution is canceled.
+ * _Finished_, processing is complete and waiting to fetch results.
+ * _Closed_, final state when client closed the statement.
+* **Detail** of the execution plan with parsed logical plan, analyzed logical plan, optimized logical plan and physical plan or errors in the SQL statement.
+
+
+
+
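As a small, hypothetical illustration of the group-id behaviour mentioned above (not part of this patch), jobs submitted under the same group id can be cancelled together from the `SparkContext`; the group id and application name below are placeholders.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder.appName("CancelByGroupId").getOrCreate()
val sc = spark.sparkContext

// Submit work under an explicit job group id...
sc.setJobGroup("example-group", "operations submitted from one statement")

// ...and cancel every running job in that group, e.g. from another thread.
sc.cancelJobGroup("example-group")
```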
diff --git a/examples/pom.xml b/examples/pom.xml
index ac148ef4c9c01..a099f1e042e99 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -107,7 +107,7 @@
com.github.scopt
scopt_${scala.binary.version}
- 3.7.0
+ 3.7.1
${hive.parquet.group}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java
deleted file mode 100644
index 324a781c1a44a..0000000000000
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-
-// $example on$
-import scala.Tuple2;
-
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.regression.LinearRegressionModel;
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
-// $example off$
-
-/**
- * Example for LinearRegressionWithSGD.
- */
-public class JavaLinearRegressionWithSGDExample {
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithSGDExample");
- JavaSparkContext sc = new JavaSparkContext(conf);
-
- // $example on$
- // Load and parse the data
- String path = "data/mllib/ridge-data/lpsa.data";
- JavaRDD data = sc.textFile(path);
- JavaRDD parsedData = data.map(line -> {
- String[] parts = line.split(",");
- String[] features = parts[1].split(" ");
- double[] v = new double[features.length];
- for (int i = 0; i < features.length - 1; i++) {
- v[i] = Double.parseDouble(features[i]);
- }
- return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
- });
- parsedData.cache();
-
- // Building the model
- int numIterations = 100;
- double stepSize = 0.00000001;
- LinearRegressionModel model =
- LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize);
-
- // Evaluate model on training examples and compute training error
- JavaPairRDD valuesAndPreds = parsedData.mapToPair(point ->
- new Tuple2<>(model.predict(point.features()), point.label()));
-
- double MSE = valuesAndPreds.mapToDouble(pair -> {
- double diff = pair._1() - pair._2();
- return diff * diff;
- }).mean();
- System.out.println("training Mean Squared Error = " + MSE);
-
- // Save and load model
- model.save(sc.sc(), "target/tmp/javaLinearRegressionWithSGDModel");
- LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(),
- "target/tmp/javaLinearRegressionWithSGDModel");
- // $example off$
-
- sc.stop();
- }
-}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java
deleted file mode 100644
index 00033b5730a3d..0000000000000
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib;
-
-// $example on$
-import scala.Tuple2;
-
-import org.apache.spark.api.java.*;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.regression.LinearRegressionModel;
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
-import org.apache.spark.mllib.evaluation.RegressionMetrics;
-import org.apache.spark.SparkConf;
-// $example off$
-
-public class JavaRegressionMetricsExample {
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("Java Regression Metrics Example");
- JavaSparkContext sc = new JavaSparkContext(conf);
- // $example on$
- // Load and parse the data
- String path = "data/mllib/sample_linear_regression_data.txt";
- JavaRDD data = sc.textFile(path);
- JavaRDD parsedData = data.map(line -> {
- String[] parts = line.split(" ");
- double[] v = new double[parts.length - 1];
- for (int i = 1; i < parts.length; i++) {
- v[i - 1] = Double.parseDouble(parts[i].split(":")[1]);
- }
- return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
- });
- parsedData.cache();
-
- // Building the model
- int numIterations = 100;
- LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData),
- numIterations);
-
- // Evaluate model on training examples and compute training error
- JavaPairRDD valuesAndPreds = parsedData.mapToPair(point ->
- new Tuple2<>(model.predict(point.features()), point.label()));
-
- // Instantiate metrics object
- RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd());
-
- // Squared error
- System.out.format("MSE = %f\n", metrics.meanSquaredError());
- System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError());
-
- // R-squared
- System.out.format("R Squared = %f\n", metrics.r2());
-
- // Mean absolute error
- System.out.format("MAE = %f\n", metrics.meanAbsoluteError());
-
- // Explained variance
- System.out.format("Explained Variance = %f\n", metrics.explainedVariance());
-
- // Save and load model
- model.save(sc.sc(), "target/tmp/LogisticRegressionModel");
- LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(),
- "target/tmp/LogisticRegressionModel");
- // $example off$
-
- sc.stop();
- }
-}
diff --git a/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala b/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala
index 5d9a9a73f12ec..36da10568989d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.SparkSession
* accumulator source) are reported to stdout as well.
*/
object AccumulatorMetricsTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index 3311de12dbd97..d7e79966037cc 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SparkSession
* Usage: BroadcastTest [partitions] [numElem] [blockSize]
*/
object BroadcastTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val blockSize = if (args.length > 2) args(2) else "4096"
diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala
index d12ef642bd2cd..ed56108f4b624 100644
--- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala
@@ -27,7 +27,7 @@ import org.apache.spark.util.Utils
* test driver submission in the standalone scheduler.
*/
object DriverSubmissionTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 1) {
println("Usage: DriverSubmissionTest ")
System.exit(0)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
index 45c4953a84be2..6e95318a8cbc0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.examples
import org.apache.spark.sql.SparkSession
object ExceptionHandlingTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("ExceptionHandlingTest")
diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
index 2f2bbb1275438..c07c1afbcb174 100644
--- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession
* Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
*/
object GroupByTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("GroupBy Test")
diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
index b327e13533b81..48698678571e3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession
object HdfsTest {
/** Usage: HdfsTest [file] */
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 1) {
System.err.println("Usage: HdfsTest ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
index 3f9cea35d6503..87c2f6853807a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
@@ -93,7 +93,7 @@ object LocalALS {
new CholeskyDecomposition(XtX).getSolver.solve(Xty)
}
- def showWarning() {
+ def showWarning(): Unit = {
System.err.println(
"""WARN: This is a naive implementation of ALS and is given as an example!
|Please use org.apache.spark.ml.recommendation.ALS
@@ -101,7 +101,7 @@ object LocalALS {
""".stripMargin)
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
args match {
case Array(m, u, f, iters) =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
index 5512e33e41ac3..5478c585a959e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
@@ -39,7 +39,7 @@ object LocalFileLR {
DataPoint(new DenseVector(nums.slice(1, D + 1)), nums(0))
}
- def showWarning() {
+ def showWarning(): Unit = {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
|Please use org.apache.spark.ml.classification.LogisticRegression
@@ -47,7 +47,7 @@ object LocalFileLR {
""".stripMargin)
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
showWarning()
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala
index f5162a59522f0..4a73466841f69 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala
@@ -62,7 +62,7 @@ object LocalKMeans {
bestIndex
}
- def showWarning() {
+ def showWarning(): Unit = {
System.err.println(
"""WARN: This is a naive implementation of KMeans Clustering and is given as an example!
|Please use org.apache.spark.ml.clustering.KMeans
@@ -70,7 +70,7 @@ object LocalKMeans {
""".stripMargin)
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
showWarning()
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
index bde8ccd305960..4ca0ecdcfe6e0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
@@ -46,7 +46,7 @@ object LocalLR {
Array.tabulate(N)(generatePoint)
}
- def showWarning() {
+ def showWarning(): Unit = {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
|Please use org.apache.spark.ml.classification.LogisticRegression
@@ -54,7 +54,7 @@ object LocalLR {
""".stripMargin)
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
showWarning()
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala
index a93c15c85cfc1..7660ffd02ed9b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala
@@ -21,7 +21,7 @@ package org.apache.spark.examples
import scala.math.random
object LocalPi {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
var count = 0
for (i <- 1 to 100000) {
val x = random * 2 - 1
diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
index 03187aee044e4..e2120eaee6e5a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
@@ -41,7 +41,7 @@ object LogQuery {
| 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.split('\n').mkString
)
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf().setAppName("Log Query")
val sc = new SparkContext(sparkConf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
index e6f33b7adf5d1..4bea5cae775cb 100644
--- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession
* Usage: MultiBroadcastTest [partitions] [numElem]
*/
object MultiBroadcastTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
index 2332a661f26a0..2bd7c3e954396 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession
* Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio]
*/
object SimpleSkewedGroupByTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("SimpleSkewedGroupByTest")
diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
index 4d3c34041bc17..2e7abd62dcdc6 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession
* Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
*/
object SkewedGroupByTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("GroupBy Test")
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
index d3e7b7a967de7..651f0224d4402 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
@@ -78,7 +78,7 @@ object SparkALS {
new CholeskyDecomposition(XtX).getSolver.solve(Xty)
}
- def showWarning() {
+ def showWarning(): Unit = {
System.err.println(
"""WARN: This is a naive implementation of ALS and is given as an example!
|Please use org.apache.spark.ml.recommendation.ALS
@@ -86,7 +86,7 @@ object SparkALS {
""".stripMargin)
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
var slices = 0
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index 23eaa879114a9..8c09ce614d931 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -49,7 +49,7 @@ object SparkHdfsLR {
DataPoint(new DenseVector(x), y)
}
- def showWarning() {
+ def showWarning(): Unit = {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
|Please use org.apache.spark.ml.classification.LogisticRegression
@@ -57,7 +57,7 @@ object SparkHdfsLR {
""".stripMargin)
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 2) {
System.err.println("Usage: SparkHdfsLR ")
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
index b005cb6971c16..ec9b44ce6e3b7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
@@ -49,7 +49,7 @@ object SparkKMeans {
bestIndex
}
- def showWarning() {
+ def showWarning(): Unit = {
System.err.println(
"""WARN: This is a naive implementation of KMeans Clustering and is given as an example!
|Please use org.apache.spark.ml.clustering.KMeans
@@ -57,7 +57,7 @@ object SparkKMeans {
""".stripMargin)
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 3) {
System.err.println("Usage: SparkKMeans ")
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
index 4b1497345af82..deb6668f7ecfc 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
@@ -51,7 +51,7 @@ object SparkLR {
Array.tabulate(N)(generatePoint)
}
- def showWarning() {
+ def showWarning(): Unit = {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
|Please use org.apache.spark.ml.classification.LogisticRegression
@@ -59,7 +59,7 @@ object SparkLR {
""".stripMargin)
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
showWarning()
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
index 9299bad5d3290..3bd475c440d72 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.SparkSession
*/
object SparkPageRank {
- def showWarning() {
+ def showWarning(): Unit = {
System.err.println(
"""WARN: This is a naive implementation of PageRank and is given as an example!
|Please use the PageRank implementation found in org.apache.spark.graphx.lib.PageRank
@@ -47,7 +47,7 @@ object SparkPageRank {
""".stripMargin)
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 1) {
System.err.println("Usage: SparkPageRank ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
index 828d98b5001d7..a8eec6a99cf4b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SparkSession
/** Computes an approximation to pi */
object SparkPi {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("Spark Pi")
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala
index 64076f2deb706..99a12b9442365 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.SparkSession
/** Usage: SparkRemoteFileTest [file] */
object SparkRemoteFileTest {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 1) {
System.err.println("Usage: SparkRemoteFileTest ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
index f5d42141f5dd2..7a6fa9a797ff9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
@@ -41,7 +41,7 @@ object SparkTC {
edges.toSeq
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("SparkTC")
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala
index da3ffca1a6f2a..af18c0afbb223 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala
@@ -23,7 +23,7 @@ package org.apache.spark.examples.graphx
* http://snap.stanford.edu/data/soc-LiveJournal1.html.
*/
object LiveJournalPageRank {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 1) {
System.err.println(
"Usage: LiveJournalPageRank \n" +
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
index 57b2edf992208..8bc9c0a86eab6 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
@@ -47,7 +47,7 @@ object SynthBenchmark {
* -degFile the local file to save the degree information (Default: Empty)
* -seed seed to use for RNGs (Default: -1, picks seed randomly)
*/
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val options = args.map {
arg =>
arg.dropWhile(_ == '-').split('=') match {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala
index 8091838a2301e..354e65c2bae38 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala
@@ -42,7 +42,7 @@ object ALSExample {
}
// $example off$
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("ALSExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala
index 5638e66b8792a..1a67a6e755ab4 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
object ChiSqSelectorExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("ChiSqSelectorExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala
index 91d861dd4380a..947ca5f5fb5e1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}
import org.apache.spark.sql.SparkSession
object CountVectorizerExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("CountVectorizerExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
index ee4469faab3a0..4377efd9e95fa 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
@@ -41,7 +41,7 @@ object DataFrameExample {
case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("DataFrameExample") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 19f2d7751bc54..ef38163d7eb0d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -65,7 +65,7 @@ object DecisionTreeExample {
checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("DecisionTreeExample") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 2dc11b07d88ef..9b5dfed0cb31b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.{Dataset, Row, SparkSession}
*/
object DeveloperApiExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("DeveloperApiExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
index 8f3ce4b315bd3..ca4235d53e636 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -63,7 +63,7 @@ object GBTExample {
checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("GBTExample") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala
index 2940682c32801..b3642c0b45db6 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession
object IndexToStringExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("IndexToStringExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala
index 6903a1c298ced..370c6fd7c17fc 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala
@@ -50,7 +50,7 @@ object LinearRegressionExample {
tol: Double = 1E-6,
fracTest: Double = 0.2) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("LinearRegressionExample") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala
index bd6cc8cff2348..b64ab4792add4 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala
@@ -55,7 +55,7 @@ object LogisticRegressionExample {
tol: Double = 1E-6,
fracTest: Double = 0.2) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("LogisticRegressionExample") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
index 4ad6c7c3ef202..86e70e8ab0189 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.SparkSession
*/
object OneVsRestExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName(s"OneVsRestExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala
index 0fe16fb6dfa9f..55823fe1832e5 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.sql.SparkSession
object QuantileDiscretizerExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("QuantileDiscretizerExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
index 3c127a46e1f10..6ba14bcd1822f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -64,7 +64,7 @@ object RandomForestExample {
checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("RandomForestExample") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala
index bb4587b82cb37..bf6a4846b6e34 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.SQLTransformer
import org.apache.spark.sql.SparkSession
object SQLTransformerExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("SQLTransformerExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala
index ec2df2ef876ba..6121c81cd1f5d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.SparkSession
object TfIdfExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("TfIdfExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
index b4179ecc1e56d..05f2ee3288624 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
@@ -82,7 +82,7 @@ object UnaryTransformerExample {
object MyTransformer extends DefaultParamsReadable[MyTransformer]
// $example off$
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("UnaryTransformerExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala
index 4bcc6ac6a01f5..8ff0e8c6a51c8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
object Word2VecExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("Word2Vec example")
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala
index a07535bb5a38d..1a7839414b38e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
object AssociationRulesExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("AssociationRulesExample")
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
index e3cc1d9c83361..6fc3501fc57b5 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
@@ -58,7 +58,7 @@ object BinaryClassification {
regType: RegType = L2,
regParam: Double = 0.01) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("BinaryClassification") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala
index 53d0b8fc208ef..b7f0ba00f913e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala
@@ -34,7 +34,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
*/
object BisectingKMeansExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf().setAppName("mllib.BisectingKMeansExample")
val sc = new SparkContext(sparkConf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
index 0b44c339ef139..cf9f7adbf6999 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
@@ -37,7 +37,7 @@ object Correlations {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
index 681465d2176d4..9082f0b5a8b85 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
@@ -45,7 +45,7 @@ object CosineSimilarity {
case class Params(inputFile: String = null, threshold: Double = 0.1)
extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("CosineSimilarity") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index b5d1b02f92524..1029ca04c348f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -67,7 +67,7 @@ object DecisionTreeRunner {
checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("DecisionTreeRunner") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
index b228827e5886f..0259df2799174 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
@@ -47,7 +47,7 @@ object DenseKMeans {
numIterations: Int = 10,
initializationMode: InitializationMode = Parallel) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("DenseKMeans") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala
index f724ee1030f04..a25ce826ee842 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala
@@ -35,7 +35,7 @@ object FPGrowthExample {
minSupport: Double = 0.3,
numPartition: Int = -1) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("FPGrowthExample") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala
index b1b3a79d87ae1..103d212a80e78 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors
object GaussianMixtureExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GaussianMixtureExample")
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
index 3f264933cd3cc..12e0c8df274b2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -50,7 +50,7 @@ object GradientBoostedTreesRunner {
numIterations: Int = 10,
fracTest: Double = 0.2) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("GradientBoostedTrees") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala
index 9b3c3266ee30a..8435209377553 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala
@@ -29,7 +29,7 @@ import org.apache.spark.rdd.RDD
object HypothesisTestingExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("HypothesisTestingExample")
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala
index b0a6f1671a898..17ebd4159b8d7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors
object KMeansExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("KMeansExample")
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index cd77ecf990b3b..605ca68e627ec 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -53,7 +53,7 @@ object LDAExample {
checkpointDir: Option[String] = None,
checkpointInterval: Int = 10) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("LDAExample") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala
index d25962c5500ed..55a45b302b5a3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors
object LatentDirichletAllocationExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("LatentDirichletAllocationExample")
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
deleted file mode 100644
index 03222b13ad27d..0000000000000
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// scalastyle:off println
-package org.apache.spark.examples.mllib
-
-import org.apache.log4j.{Level, Logger}
-import scopt.OptionParser
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater}
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD
-import org.apache.spark.mllib.util.MLUtils
-
-/**
- * An example app for linear regression. Run with
- * {{{
- * bin/run-example org.apache.spark.examples.mllib.LinearRegression
- * }}}
- * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt`.
- * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
- */
-@deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0")
-object LinearRegression {
-
- object RegType extends Enumeration {
- type RegType = Value
- val NONE, L1, L2 = Value
- }
-
- import RegType._
-
- case class Params(
- input: String = null,
- numIterations: Int = 100,
- stepSize: Double = 1.0,
- regType: RegType = L2,
- regParam: Double = 0.01) extends AbstractParams[Params]
-
- def main(args: Array[String]) {
- val defaultParams = Params()
-
- val parser = new OptionParser[Params]("LinearRegression") {
- head("LinearRegression: an example app for linear regression.")
- opt[Int]("numIterations")
- .text("number of iterations")
- .action((x, c) => c.copy(numIterations = x))
- opt[Double]("stepSize")
- .text(s"initial step size, default: ${defaultParams.stepSize}")
- .action((x, c) => c.copy(stepSize = x))
- opt[String]("regType")
- .text(s"regularization type (${RegType.values.mkString(",")}), " +
- s"default: ${defaultParams.regType}")
- .action((x, c) => c.copy(regType = RegType.withName(x)))
- opt[Double]("regParam")
- .text(s"regularization parameter, default: ${defaultParams.regParam}")
- arg[String](" ")
- .required()
- .text("input paths to labeled examples in LIBSVM format")
- .action((x, c) => c.copy(input = x))
- note(
- """
- |For example, the following command runs this app on a synthetic dataset:
- |
- | bin/spark-submit --class org.apache.spark.examples.mllib.LinearRegression \
- | examples/target/scala-*/spark-examples-*.jar \
- | data/mllib/sample_linear_regression_data.txt
- """.stripMargin)
- }
-
- parser.parse(args, defaultParams) match {
- case Some(params) => run(params)
- case _ => sys.exit(1)
- }
- }
-
- def run(params: Params): Unit = {
- val conf = new SparkConf().setAppName(s"LinearRegression with $params")
- val sc = new SparkContext(conf)
-
- Logger.getRootLogger.setLevel(Level.WARN)
-
- val examples = MLUtils.loadLibSVMFile(sc, params.input).cache()
-
- val splits = examples.randomSplit(Array(0.8, 0.2))
- val training = splits(0).cache()
- val test = splits(1).cache()
-
- val numTraining = training.count()
- val numTest = test.count()
- println(s"Training: $numTraining, test: $numTest.")
-
- examples.unpersist()
-
- val updater = params.regType match {
- case NONE => new SimpleUpdater()
- case L1 => new L1Updater()
- case L2 => new SquaredL2Updater()
- }
-
- val algorithm = new LinearRegressionWithSGD()
- algorithm.optimizer
- .setNumIterations(params.numIterations)
- .setStepSize(params.stepSize)
- .setUpdater(updater)
- .setRegParam(params.regParam)
-
- val model = algorithm.run(training)
-
- val prediction = model.predict(test.map(_.features))
- val predictionAndLabel = prediction.zip(test.map(_.label))
-
- val loss = predictionAndLabel.map { case (p, l) =>
- val err = p - l
- err * err
- }.reduce(_ + _)
- val rmse = math.sqrt(loss / numTest)
-
- println(s"Test RMSE = $rmse.")
-
- sc.stop()
- }
-}
-// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala
deleted file mode 100644
index 449b725d1d173..0000000000000
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// scalastyle:off println
-package org.apache.spark.examples.mllib
-
-import org.apache.spark.{SparkConf, SparkContext}
-// $example on$
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.regression.LinearRegressionModel
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD
-// $example off$
-
-@deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0")
-object LinearRegressionWithSGDExample {
-
- def main(args: Array[String]): Unit = {
- val conf = new SparkConf().setAppName("LinearRegressionWithSGDExample")
- val sc = new SparkContext(conf)
-
- // $example on$
- // Load and parse the data
- val data = sc.textFile("data/mllib/ridge-data/lpsa.data")
- val parsedData = data.map { line =>
- val parts = line.split(',')
- LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
- }.cache()
-
- // Building the model
- val numIterations = 100
- val stepSize = 0.00000001
- val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize)
-
- // Evaluate model on training examples and compute training error
- val valuesAndPreds = parsedData.map { point =>
- val prediction = model.predict(point.features)
- (point.label, prediction)
- }
- val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean()
- println(s"training Mean Squared Error $MSE")
-
- // Save and load model
- model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel")
- val sameModel = LinearRegressionModel.load(sc, "target/tmp/scalaLinearRegressionWithSGDModel")
- // $example off$
-
- sc.stop()
- }
-}
-// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index fd810155d6a88..92c85c9271a5a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -48,7 +48,7 @@ object MovieLensALS {
numProductBlocks: Int = -1,
implicitPrefs: Boolean = false) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("MovieLensALS") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
index f9e47e485e72f..b5c52f9a31224 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
@@ -38,7 +38,7 @@ object MultivariateSummarizer {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala
deleted file mode 100644
index eff2393cc3abe..0000000000000
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// scalastyle:off println
-package org.apache.spark.examples.mllib
-
-import org.apache.spark.SparkConf
-import org.apache.spark.SparkContext
-// $example on$
-import org.apache.spark.mllib.feature.PCA
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
-// $example off$
-
-@deprecated("Deprecated since LinearRegressionWithSGD is deprecated. Use ml.feature.PCA", "2.0.0")
-object PCAExample {
-
- def main(args: Array[String]): Unit = {
-
- val conf = new SparkConf().setAppName("PCAExample")
- val sc = new SparkContext(conf)
-
- // $example on$
- val data = sc.textFile("data/mllib/ridge-data/lpsa.data").map { line =>
- val parts = line.split(',')
- LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
- }.cache()
-
- val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
- val training = splits(0).cache()
- val test = splits(1)
-
- val pca = new PCA(training.first().features.size / 2).fit(data.map(_.features))
- val training_pca = training.map(p => p.copy(features = pca.transform(p.features)))
- val test_pca = test.map(p => p.copy(features = pca.transform(p.features)))
-
- val numIterations = 100
- val model = LinearRegressionWithSGD.train(training, numIterations)
- val model_pca = LinearRegressionWithSGD.train(training_pca, numIterations)
-
- val valuesAndPreds = test.map { point =>
- val score = model.predict(point.features)
- (score, point.label)
- }
-
- val valuesAndPreds_pca = test_pca.map { point =>
- val score = model_pca.predict(point.features)
- (score, point.label)
- }
-
- val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean()
- val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean()
-
- println(s"Mean Squared Error = $MSE")
- println(s"PCA Mean Squared Error = $MSE_pca")
- // $example off$
-
- sc.stop()
- }
-}
-// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
index 65603252c4384..eaf1dacd0160a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
@@ -62,7 +62,7 @@ object PowerIterationClusteringExample {
maxIterations: Int = 15
) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("PowerIterationClusteringExample") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala
index 8b789277774af..1b5d919a047e8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala
@@ -25,7 +25,7 @@ import org.apache.spark.mllib.fpm.PrefixSpan
object PrefixSpanExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("PrefixSpanExample")
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala
index 7ccbb5a0640cd..aee12a1b4751f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala
@@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD
*/
object RandomRDDGeneration {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName(s"RandomRDDGeneration")
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
index ea13ec05e2fad..2845028dd0814 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
@@ -25,7 +25,7 @@ import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.sql.SparkSession
object RankingMetricsExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("RankingMetricsExample")
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala
deleted file mode 100644
index 76cfb804e18f3..0000000000000
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-// scalastyle:off println
-
-package org.apache.spark.examples.mllib
-
-// $example on$
-import org.apache.spark.mllib.evaluation.RegressionMetrics
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
-// $example off$
-import org.apache.spark.sql.SparkSession
-
-@deprecated("Use ml.regression.LinearRegression and the resulting model summary for metrics",
- "2.0.0")
-object RegressionMetricsExample {
- def main(args: Array[String]): Unit = {
- val spark = SparkSession
- .builder
- .appName("RegressionMetricsExample")
- .getOrCreate()
- // $example on$
- // Load the data
- val data = spark
- .read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt")
- .rdd.map(row => LabeledPoint(row.getDouble(0), row.get(1).asInstanceOf[Vector]))
- .cache()
-
- // Build the model
- val numIterations = 100
- val model = LinearRegressionWithSGD.train(data, numIterations)
-
- // Get predictions
- val valuesAndPreds = data.map{ point =>
- val prediction = model.predict(point.features)
- (prediction, point.label)
- }
-
- // Instantiate metrics object
- val metrics = new RegressionMetrics(valuesAndPreds)
-
- // Squared error
- println(s"MSE = ${metrics.meanSquaredError}")
- println(s"RMSE = ${metrics.rootMeanSquaredError}")
-
- // R-squared
- println(s"R-squared = ${metrics.r2}")
-
- // Mean absolute error
- println(s"MAE = ${metrics.meanAbsoluteError}")
-
- // Explained variance
- println(s"Explained variance = ${metrics.explainedVariance}")
- // $example off$
-
- spark.stop()
- }
-}
-// scalastyle:on println
-
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
index ba3deae5d688f..fdde47d60c544 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
@@ -35,7 +35,7 @@ object SampledRDDs {
case class Params(input: String = "data/mllib/sample_binary_classification_data.txt")
extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("SampledRDDs") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala
index 694c3bb18b045..ba16e8f5ff347 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
object SimpleFPGrowth {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("SimpleFPGrowth")
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
index b76add2f9bc99..b501f4db2efbb 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
@@ -40,7 +40,7 @@ object SparseNaiveBayes {
numFeatures: Int = -1,
lambda: Double = 1.0) extends AbstractParams[Params]
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val defaultParams = Params()
val parser = new OptionParser[Params]("SparseNaiveBayes") {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala
index 7888af79f87f4..5186f599d9628 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala
@@ -52,7 +52,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
*/
object StreamingKMeansExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 5) {
System.err.println(
"Usage: StreamingKMeansExample " +
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala
index a8b144a197229..4c72f444ff9ec 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala
@@ -46,7 +46,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
*/
object StreamingLogisticRegression {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 4) {
System.err.println(
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala
index ae4dee24c6474..f60b10a02274b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala
@@ -44,7 +44,7 @@ import org.apache.spark.util.Utils
*/
object StreamingTestExample {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 3) {
// scalastyle:off println
System.err.println(
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala
index 071d341b81614..6b839f3f4ac1e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala
@@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix
* represents a 3-by-2 matrix, whose first row is (0.5, 1.0).
*/
object TallSkinnyPCA {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 1) {
System.err.println("Usage: TallSkinnyPCA ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala
index 8ae6de16d80e7..8874c2eda3d2e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala
@@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix
* represents a 3-by-2 matrix, whose first row is (0.5, 1.0).
*/
object TallSkinnySVD {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 1) {
System.err.println("Usage: TallSkinnySVD ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
index deaa9f252b9b0..4fd482d5b8bf7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession
case class Record(key: Int, value: String)
object RDDRelation {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
// $example on:init_session$
val spark = SparkSession
.builder
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
index c7b6a50f0ae7c..d4c05e5ad9944 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
@@ -24,7 +24,7 @@ object SQLDataSourceExample {
case class Person(name: String, age: Long)
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("Spark SQL data sources example")
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala
index 678cbc64aff1f..fde281087c267 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala
@@ -34,7 +34,7 @@ object SparkSQLExample {
case class Person(name: String, age: Long)
// $example off:create_ds$
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
// $example on:init_session$
val spark = SparkSession
.builder()
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala
index a832276602b88..3be8a3862f39c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala
@@ -28,7 +28,7 @@ object SparkHiveExample {
case class Record(key: Int, value: String)
// $example off:spark_hive$
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
// When working with Hive, one must instantiate `SparkSession` with Hive support, including
// connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined
// functions. Users who do not have an existing Hive deployment can still enable Hive support.
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala
index de477c5ce8161..6dbc70bd141f3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.SparkSession
* localhost 9999`
*/
object StructuredNetworkWordCount {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 2) {
System.err.println("Usage: StructuredNetworkWordCount ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala
index b4dad21dd75b0..4ba2c6bc68918 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala
@@ -48,7 +48,7 @@ import org.apache.spark.sql.functions._
*/
object StructuredNetworkWordCountWindowed {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 3) {
System.err.println("Usage: StructuredNetworkWordCountWindowed " +
" []")
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala
index fc3f8fa53c7ae..0f47deaf1021b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala
@@ -37,7 +37,7 @@ import org.apache.spark.streaming.receiver.Receiver
* `$ bin/run-example org.apache.spark.examples.streaming.CustomReceiver localhost 9999`
*/
object CustomReceiver {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 2) {
System.err.println("Usage: CustomReceiver ")
System.exit(1)
@@ -64,20 +64,20 @@ object CustomReceiver {
class CustomReceiver(host: String, port: Int)
extends Receiver[String](StorageLevel.MEMORY_AND_DISK_2) {
- def onStart() {
+ def onStart(): Unit = {
// Start the thread that receives data over a connection
new Thread("Socket Receiver") {
- override def run() { receive() }
+ override def run(): Unit = { receive() }
}.start()
}
- def onStop() {
+ def onStop(): Unit = {
// There is nothing much to do as the thread calling receive()
// is designed to stop by itself if isStopped() returns false
}
/** Create a socket connection and receive data until receiver is stopped */
- private def receive() {
+ private def receive(): Unit = {
var socket: Socket = null
var userInput: String = null
try {
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
index 3024b59480099..6fdb37194ea7d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
@@ -37,7 +37,7 @@ import org.apache.spark.streaming.kafka010._
* consumer-group topic1,topic2
*/
object DirectKafkaWordCount {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 3) {
System.err.println(s"""
|Usage: DirectKafkaWordCount <brokers> <groupId> <topics>
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKerberizedKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKerberizedKafkaWordCount.scala
index b68a59873a8fe..6a35ce9b2a293 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKerberizedKafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKerberizedKafkaWordCount.scala
@@ -76,7 +76,7 @@ import org.apache.spark.streaming.kafka010._
* using SASL_SSL in production.
*/
object DirectKerberizedKafkaWordCount {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 3) {
System.err.println(s"""
|Usage: DirectKerberizedKafkaWordCount <brokers> <groupId> <topics>
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala
index 1f282d437dc38..19dc7a3cce0ac 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala
@@ -33,7 +33,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
* Then create a text file in `localdir` and the words in the file will get counted.
*/
object HdfsWordCount {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 1) {
System.err.println("Usage: HdfsWordCount ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala
index 15b57fccb4076..26bb51dde3a1d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala
@@ -34,7 +34,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
* `$ bin/run-example org.apache.spark.examples.streaming.NetworkWordCount localhost 9999`
*/
object NetworkWordCount {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 2) {
System.err.println("Usage: NetworkWordCount ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala
index 19bacd449787b..09eeaf9fa4496 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala
@@ -25,7 +25,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
object QueueStream {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
StreamingExamples.setStreamingLogLevels()
val sparkConf = new SparkConf().setAppName("QueueStream")
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala
index 437ccf0898d7c..a20abd6e9d12e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala
@@ -37,7 +37,7 @@ import org.apache.spark.util.IntParam
* <batchMillis> is the Spark Streaming batch duration in milliseconds.
*/
object RawNetworkGrep {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 4) {
System.err.println("Usage: RawNetworkGrep ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
index f018f3a26d2e9..243c22e71275c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
@@ -139,7 +139,7 @@ object RecoverableNetworkWordCount {
ssc
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 4) {
System.err.println(s"Your arguments were ${args.mkString("[", ", ", "]")}")
System.err.println(
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala
index 787bbec73b28f..778be7baaeeac 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala
@@ -38,7 +38,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext, Time}
*/
object SqlNetworkWordCount {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 2) {
System.err.println("Usage: NetworkWordCount ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index 2811e67009fb0..46f01edf7deec 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -35,7 +35,7 @@ import org.apache.spark.streaming._
* org.apache.spark.examples.streaming.StatefulNetworkWordCount localhost 9999`
*/
object StatefulNetworkWordCount {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 2) {
System.err.println("Usage: StatefulNetworkWordCount ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala
index b00f32fb25243..073f9728c68af 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala
@@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging
object StreamingExamples extends Logging {
/** Set reasonable logging levels for streaming if the user has not configured log4j. */
- def setStreamingLogLevels() {
+ def setStreamingLogLevels(): Unit = {
val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements
if (!log4jInitialized) {
// We first log something to initialize Spark's default logging, then we override the
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala
index 2108bc63edea2..7234f30e7d267 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala
@@ -81,7 +81,7 @@ object PageViewGenerator {
new PageView(page, status, zipCode, id).toString()
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 2) {
System.err.println("Usage: PageViewGenerator ")
System.exit(1)
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala
index b8e7c7e9e9152..b51bfacabf4aa 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala
@@ -35,7 +35,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
*/
// scalastyle:on
object PageViewStream {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 3) {
System.err.println("Usage: PageViewStream ")
System.err.println(" must be one of pageCounts, slidingPageCounts," +
diff --git a/external/avro/benchmarks/AvroReadBenchmark-jdk11-results.txt b/external/avro/benchmarks/AvroReadBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..94137a691e4aa
--- /dev/null
+++ b/external/avro/benchmarks/AvroReadBenchmark-jdk11-results.txt
@@ -0,0 +1,122 @@
+================================================================================================
+SQL Single Numeric Column Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 2995 3081 121 5.3 190.4 1.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 2865 2881 23 5.5 182.2 1.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 2919 2936 23 5.4 185.6 1.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 3148 3262 161 5.0 200.1 1.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 2651 2721 99 5.9 168.5 1.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 2782 2854 103 5.7 176.9 1.0X
+
+
+================================================================================================
+Int and String Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of columns 4531 4583 73 2.3 432.1 1.0X
+
+
+================================================================================================
+Partitioned Table Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Data column 3084 3105 30 5.1 196.1 1.0X
+Partition column 3143 3164 30 5.0 199.8 1.0X
+Both columns 3272 3339 94 4.8 208.1 0.9X
+
+
+================================================================================================
+Repeated String Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of string length 3249 3318 98 3.2 309.8 1.0X
+
+
+================================================================================================
+String with Nulls Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of string length 5308 5335 38 2.0 506.2 1.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of string length 4405 4429 33 2.4 420.1 1.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of string length 3256 3309 75 3.2 310.5 1.0X
+
+
+================================================================================================
+Single Column Scan From Wide Columns
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of single column 5230 5290 85 0.2 4987.4 1.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of single column 10206 10329 174 0.1 9733.1 1.0X
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of single column 15333 15365 46 0.1 14622.3 1.0X
+
+
diff --git a/external/avro/benchmarks/AvroReadBenchmark-results.txt b/external/avro/benchmarks/AvroReadBenchmark-results.txt
index 7900fea453b10..7b008a312c320 100644
--- a/external/avro/benchmarks/AvroReadBenchmark-results.txt
+++ b/external/avro/benchmarks/AvroReadBenchmark-results.txt
@@ -2,121 +2,121 @@
SQL Single Numeric Column Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum 2774 / 2815 5.7 176.4 1.0X
+SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 3067 3132 91 5.1 195.0 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum 2761 / 2777 5.7 175.5 1.0X
+SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 2927 2929 3 5.4 186.1 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum 2783 / 2870 5.7 176.9 1.0X
+SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 2928 2990 87 5.4 186.2 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum 3256 / 3266 4.8 207.0 1.0X
+SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 3374 3447 104 4.7 214.5 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum 2841 / 2867 5.5 180.6 1.0X
+SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 2896 2901 7 5.4 184.1 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum 2981 / 2996 5.3 189.5 1.0X
+SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum 3004 3006 3 5.2 191.0 1.0X
================================================================================================
Int and String Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum of columns 4781 / 4783 2.2 456.0 1.0X
+Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of columns 4814 4830 22 2.2 459.1 1.0X
================================================================================================
Partitioned Table Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Data column 3372 / 3386 4.7 214.4 1.0X
-Partition column 3035 / 3064 5.2 193.0 1.1X
-Both columns 3445 / 3461 4.6 219.1 1.0X
+Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Data column 3361 3362 1 4.7 213.7 1.0X
+Partition column 2999 3013 20 5.2 190.7 1.1X
+Both columns 3613 3615 2 4.4 229.7 0.9X
================================================================================================
Repeated String Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum of string length 3395 / 3401 3.1 323.8 1.0X
+Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of string length 3415 3416 1 3.1 325.7 1.0X
================================================================================================
String with Nulls Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum of string length 5580 / 5624 1.9 532.2 1.0X
+String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of string length 5535 5536 2 1.9 527.8 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum of string length 4622 / 4623 2.3 440.8 1.0X
+String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of string length 4567 4575 11 2.3 435.6 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum of string length 3238 / 3241 3.2 308.8 1.0X
+String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of string length 3248 3268 29 3.2 309.7 1.0X
================================================================================================
Single Column Scan From Wide Columns
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum of single column 5472 / 5484 0.2 5218.8 1.0X
+Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of single column 5486 5497 15 0.2 5232.0 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum of single column 10680 / 10701 0.1 10185.1 1.0X
+Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of single column 10682 10746 90 0.1 10186.8 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Sum of single column 16143 / 16238 0.1 15394.9 1.0X
+Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Sum of single column 16177 16177 0 0.1 15427.7 1.0X
diff --git a/external/avro/benchmarks/AvroWriteBenchmark-jdk11-results.txt b/external/avro/benchmarks/AvroWriteBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..2cf1835013821
--- /dev/null
+++ b/external/avro/benchmarks/AvroWriteBenchmark-jdk11-results.txt
@@ -0,0 +1,10 @@
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Avro writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Output Single Int Column 3026 3142 164 5.2 192.4 1.0X
+Output Single Double Column 3157 3260 145 5.0 200.7 1.0X
+Output Int and String Column 6123 6190 94 2.6 389.3 0.5X
+Output Partitions 5197 5733 758 3.0 330.4 0.6X
+Output Buckets 7074 7285 298 2.2 449.7 0.4X
+
diff --git a/external/avro/benchmarks/AvroWriteBenchmark-results.txt b/external/avro/benchmarks/AvroWriteBenchmark-results.txt
index fb2a77333eec5..20f6ae9099a4d 100644
--- a/external/avro/benchmarks/AvroWriteBenchmark-results.txt
+++ b/external/avro/benchmarks/AvroWriteBenchmark-results.txt
@@ -1,10 +1,10 @@
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Avro writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Output Single Int Column 3213 / 3373 4.9 204.3 1.0X
-Output Single Double Column 3313 / 3345 4.7 210.7 1.0X
-Output Int and String Column 7303 / 7316 2.2 464.3 0.4X
-Output Partitions 5309 / 5691 3.0 337.5 0.6X
-Output Buckets 7031 / 7557 2.2 447.0 0.5X
+Avro writer benchmark: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Output Single Int Column 3080 3137 82 5.1 195.8 1.0X
+Output Single Double Column 3595 3595 0 4.4 228.6 0.9X
+Output Int and String Column 7491 7504 18 2.1 476.3 0.4X
+Output Partitions 5518 5663 205 2.9 350.8 0.6X
+Output Buckets 7467 7581 161 2.1 474.7 0.4X
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
index 3171f1e08b4fc..c6f52d676422c 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.v2.avro
import org.apache.spark.sql.avro.AvroFileFormat
+import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
-import org.apache.spark.sql.sources.v2.Table
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
index 243af7da47003..0397d15aed924 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
@@ -31,10 +31,10 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.v2.{EmptyPartitionReader, FilePartitionReaderFactory, PartitionReaderWithPartitionValues}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.reader.PartitionReader
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
index 6ec351080a118..e1268ac2ce581 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
@@ -21,9 +21,9 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScan
-import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
index 815da2bd92d44..e36c71ef4b1f7 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.v2.avro
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
-import org.apache.spark.sql.sources.v2.reader.Scan
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
index a781624aa61aa..765e5727d944a 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
@@ -22,9 +22,9 @@ import org.apache.hadoop.fs.FileStatus
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.avro.AvroUtils
+import org.apache.spark.sql.connector.write.WriteBuilder
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.FileTable
-import org.apache.spark.sql.sources.v2.writer.WriteBuilder
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index cf88981b1efbd..dc60cfe41ca7a 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -1036,7 +1036,7 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession {
(TimestampType, LONG),
(DecimalType(4, 2), BYTES)
)
- def assertException(f: () => AvroSerializer) {
+ def assertException(f: () => AvroSerializer): Unit = {
val message = intercept[org.apache.spark.sql.avro.IncompatibleSchemaException] {
f()
}.getMessage
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
index f2f7d650066fb..a16126ae24246 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
@@ -22,7 +22,6 @@ import scala.util.Random
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.types._
/**
@@ -36,7 +35,7 @@ import org.apache.spark.sql.types._
* Results will be written to "benchmarks/AvroReadBenchmark-results.txt".
* }}}
*/
-object AvroReadBenchmark extends SqlBasedBenchmark with SQLHelper {
+object AvroReadBenchmark extends SqlBasedBenchmark {
def withTempTable(tableNames: String*)(f: => Unit): Unit = {
try f finally tableNames.foreach(spark.catalog.dropTempView)
}
diff --git a/external/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile
index c1fd630d0b665..5bec5d3f16548 100644
--- a/external/docker/spark-test/base/Dockerfile
+++ b/external/docker/spark-test/base/Dockerfile
@@ -25,7 +25,7 @@ RUN apt-get update && \
apt-get install -y less openjdk-8-jre-headless iproute2 vim-tiny sudo openssh-server && \
rm -rf /var/lib/apt/lists/*
-ENV SCALA_VERSION 2.12.8
+ENV SCALA_VERSION 2.12.10
ENV CDH_VERSION cdh4
ENV SCALA_HOME /opt/scala-$SCALA_VERSION
ENV SPARK_HOME /opt/spark
diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml
index 0735f0a7b937f..693820da6af6b 100644
--- a/external/kafka-0-10-sql/pom.xml
+++ b/external/kafka-0-10-sql/pom.xml
@@ -46,6 +46,13 @@
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-token-provider-kafka-0-10_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala
index 868edb5dcdc0c..6dd5af2389a81 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala
@@ -68,7 +68,7 @@ private object JsonUtils {
partOffsets.map { case (part, offset) =>
new TopicPartition(topic, part) -> offset
}
- }.toMap
+ }
} catch {
case NonFatal(x) =>
throw new IllegalArgumentException(
@@ -76,12 +76,27 @@ private object JsonUtils {
}
}
+ def partitionTimestamps(str: String): Map[TopicPartition, Long] = {
+ try {
+ Serialization.read[Map[String, Map[Int, Long]]](str).flatMap { case (topic, partTimestamps) =>
+ partTimestamps.map { case (part, timestamp) =>
+ new TopicPartition(topic, part) -> timestamp
+ }
+ }
+ } catch {
+ case NonFatal(x) =>
+ throw new IllegalArgumentException(
+ s"""Expected e.g. {"topicA": {"0": 123456789, "1": 123456789},
+ |"topicB": {"0": 123456789, "1": 123456789}}, got $str""".stripMargin)
+ }
+ }
+
/**
* Write per-TopicPartition offsets as json string
*/
def partitionOffsets(partitionOffsets: Map[TopicPartition, Long]): String = {
val result = new HashMap[String, HashMap[Int, Long]]()
- implicit val ordering = new Ordering[TopicPartition] {
+ implicit val order = new Ordering[TopicPartition] {
override def compare(x: TopicPartition, y: TopicPartition): Int = {
Ordering.Tuple2[String, Int].compare((x.topic, x.partition), (y.topic, y.partition))
}
@@ -95,4 +110,9 @@ private object JsonUtils {
}
Serialization.write(result)
}
+
+ def partitionTimestamps(topicTimestamps: Map[TopicPartition, Long]): String = {
+ // For now it's same as partitionOffsets
+ partitionOffsets(topicTimestamps)
+ }
}
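
The added partitionTimestamps reader expects each topic mapped to per-partition epoch-millis timestamps, mirroring the existing partitionOffsets format. A minimal round-trip sketch of that JSON shape, assuming the same json4s setup the object already uses:

    import org.json4s.NoTypeHints
    import org.json4s.jackson.Serialization

    // Hedged sketch; the timestamp values are illustrative.
    implicit val formats = Serialization.formats(NoTypeHints)

    val json = """{"topicA": {"0": 1546300800000, "1": 1546300800000}}"""
    val perTopic = Serialization.read[Map[String, Map[Int, Long]]](json)
    // Map(topicA -> Map(0 -> 1546300800000, 1 -> 1546300800000))
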
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala
index 700414167f3ef..3006770f306c0 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala
@@ -23,8 +23,7 @@ import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory}
-
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory}
private[kafka010] class KafkaBatch(
strategy: ConsumerStrategy,
@@ -32,7 +31,8 @@ private[kafka010] class KafkaBatch(
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
- endingOffsets: KafkaOffsetRangeLimit)
+ endingOffsets: KafkaOffsetRangeLimit,
+ includeHeaders: Boolean)
extends Batch with Logging {
assert(startingOffsets != LatestOffsetRangeLimit,
"Starting offset not allowed to be set to latest offsets.")
@@ -59,8 +59,8 @@ private[kafka010] class KafkaBatch(
// Leverage the KafkaReader to obtain the relevant partition offsets
val (fromPartitionOffsets, untilPartitionOffsets) = {
try {
- (kafkaOffsetReader.fetchPartitionOffsets(startingOffsets),
- kafkaOffsetReader.fetchPartitionOffsets(endingOffsets))
+ (kafkaOffsetReader.fetchPartitionOffsets(startingOffsets, isStartingOffsets = true),
+ kafkaOffsetReader.fetchPartitionOffsets(endingOffsets, isStartingOffsets = false))
} finally {
kafkaOffsetReader.close()
}
@@ -91,7 +91,7 @@ private[kafka010] class KafkaBatch(
KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
offsetRanges.map { range =>
new KafkaBatchInputPartition(
- range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
+ range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
}.toArray
}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
index 53b0b3c46854e..645b68b0c407a 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
@@ -22,21 +22,21 @@ import java.{util => ju}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.sources.v2.reader._
-
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
/** A [[InputPartition]] for reading Kafka data in a batch based streaming query. */
private[kafka010] case class KafkaBatchInputPartition(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
- failOnDataLoss: Boolean) extends InputPartition
+ failOnDataLoss: Boolean,
+ includeHeaders: Boolean) extends InputPartition
private[kafka010] object KafkaBatchReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val p = partition.asInstanceOf[KafkaBatchInputPartition]
KafkaBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs,
- p.failOnDataLoss)
+ p.failOnDataLoss, p.includeHeaders)
}
}
@@ -45,12 +45,14 @@ private case class KafkaBatchPartitionReader(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
- failOnDataLoss: Boolean) extends PartitionReader[InternalRow] with Logging {
+ failOnDataLoss: Boolean,
+ includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging {
private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams)
private val rangeToRead = resolveRange(offsetRange)
- private val converter = new KafkaRecordToUnsafeRowConverter
+ private val unsafeRowProjector = new KafkaRecordToRowConverter()
+ .toUnsafeRowProjector(includeHeaders)
private var nextOffset = rangeToRead.fromOffset
private var nextRow: UnsafeRow = _
@@ -59,7 +61,7 @@ private case class KafkaBatchPartitionReader(
if (nextOffset < rangeToRead.untilOffset) {
val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss)
if (record != null) {
- nextRow = converter.toUnsafeRow(record)
+ nextRow = unsafeRowProjector(record)
nextOffset = record.offset + 1
true
} else {
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
index 47ec07ae128d2..8e29e38b2a644 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.kafka010
import java.{util => ju}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
-import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.types.StructType
/**
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala
index a9c1181a01c51..0603ae39ba622 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala
@@ -27,9 +27,9 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming._
+import org.apache.spark.sql.connector.read.InputPartition
+import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
@@ -56,6 +56,7 @@ class KafkaContinuousStream(
private[kafka010] val pollTimeoutMs =
options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
+ private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
// Initialized when creating reader factories. If this diverges from the partitions at the latest
// offsets, we need to reconfigure.
@@ -68,6 +69,8 @@ class KafkaContinuousStream(
case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets(None))
case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
+ case SpecificTimestampRangeLimit(p) => offsetReader.fetchSpecificTimestampBasedOffsets(p,
+ failsOnNoMatchingOffset = true)
}
logInfo(s"Initial offsets: $offsets")
offsets
@@ -88,7 +91,7 @@ class KafkaContinuousStream(
if (deletedPartitions.nonEmpty) {
val message = if (
offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
- s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
+ s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
@@ -102,7 +105,7 @@ class KafkaContinuousStream(
startOffsets.toSeq.map {
case (topicPartition, start) =>
KafkaContinuousInputPartition(
- topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
+ topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
}.toArray
}
@@ -153,19 +156,22 @@ class KafkaContinuousStream(
* @param pollTimeoutMs The timeout for Kafka consumer polling.
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
+ * @param includeHeaders Flag indicating whether to include Kafka records' headers.
*/
case class KafkaContinuousInputPartition(
- topicPartition: TopicPartition,
- startOffset: Long,
- kafkaParams: ju.Map[String, Object],
- pollTimeoutMs: Long,
- failOnDataLoss: Boolean) extends InputPartition
+ topicPartition: TopicPartition,
+ startOffset: Long,
+ kafkaParams: ju.Map[String, Object],
+ pollTimeoutMs: Long,
+ failOnDataLoss: Boolean,
+ includeHeaders: Boolean) extends InputPartition
object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
val p = partition.asInstanceOf[KafkaContinuousInputPartition]
new KafkaContinuousPartitionReader(
- p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss)
+ p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs,
+ p.failOnDataLoss, p.includeHeaders)
}
}
@@ -184,9 +190,11 @@ class KafkaContinuousPartitionReader(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
- failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] {
+ failOnDataLoss: Boolean,
+ includeHeaders: Boolean) extends ContinuousPartitionReader[InternalRow] {
private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
- private val converter = new KafkaRecordToUnsafeRowConverter
+ private val unsafeRowProjector = new KafkaRecordToRowConverter()
+ .toUnsafeRowProjector(includeHeaders)
private var nextKafkaOffset = startOffset
private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
@@ -225,7 +233,7 @@ class KafkaContinuousPartitionReader(
}
override def get(): UnsafeRow = {
- converter.toUnsafeRow(currentRecord)
+ unsafeRowProjector(currentRecord)
}
override def getOffset(): KafkaSourcePartitionOffset = {
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
index 87036beb9a252..ca82c908f441b 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -23,12 +23,13 @@ import java.util.concurrent.TimeoutException
import scala.collection.JavaConverters._
+import org.apache.kafka.clients.CommonClientConfigs
import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer, OffsetOutOfRangeException}
import org.apache.kafka.common.TopicPartition
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
-import org.apache.spark.kafka010.KafkaConfigUpdater
+import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaTokenClusterConf, KafkaTokenUtil}
import org.apache.spark.sql.kafka010.KafkaDataConsumer.{AvailableOffsetRange, UNKNOWN_OFFSET}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.util.{ShutdownHookManager, UninterruptibleThread}
@@ -46,6 +47,13 @@ private[kafka010] class InternalKafkaConsumer(
val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+ private[kafka010] val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(
+ SparkEnv.get.conf, kafkaParams.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG)
+ .asInstanceOf[String])
+
+ // The Kafka consumer cannot expose the params it was instantiated with, so we need to store
+ // them here. It must be updated whenever a new consumer is created.
+ private[kafka010] var kafkaParamsWithSecurity: ju.Map[String, Object] = _
private val consumer = createConsumer()
/**
@@ -106,10 +114,10 @@ private[kafka010] class InternalKafkaConsumer(
/** Create a KafkaConsumer to fetch records for `topicPartition` */
private def createConsumer(): KafkaConsumer[Array[Byte], Array[Byte]] = {
- val updatedKafkaParams = KafkaConfigUpdater("executor", kafkaParams.asScala.toMap)
- .setAuthenticationConfigIfNeeded()
+ kafkaParamsWithSecurity = KafkaConfigUpdater("executor", kafkaParams.asScala.toMap)
+ .setAuthenticationConfigIfNeeded(clusterConfig)
.build()
- val c = new KafkaConsumer[Array[Byte], Array[Byte]](updatedKafkaParams)
+ val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParamsWithSecurity)
val tps = new ju.ArrayList[TopicPartition]()
tps.add(topicPartition)
c.assign(tps)
@@ -516,13 +524,25 @@ private[kafka010] class KafkaDataConsumer(
fetchedData.withNewPoll(records.listIterator, offsetAfterPoll)
}
- private def getOrRetrieveConsumer(): InternalKafkaConsumer = _consumer match {
- case None =>
- _consumer = Option(consumerPool.borrowObject(cacheKey, kafkaParams))
- require(_consumer.isDefined, "borrowing consumer from pool must always succeed.")
- _consumer.get
+ private[kafka010] def getOrRetrieveConsumer(): InternalKafkaConsumer = {
+ if (!_consumer.isDefined) {
+ retrieveConsumer()
+ }
+ require(_consumer.isDefined, "Consumer must be defined")
+ if (!KafkaTokenUtil.isConnectorUsingCurrentToken(_consumer.get.kafkaParamsWithSecurity,
+ _consumer.get.clusterConfig)) {
+ logDebug("Cached consumer uses an old delegation token, invalidating.")
+ releaseConsumer()
+ consumerPool.invalidateKey(cacheKey)
+ fetchedDataPool.invalidate(cacheKey)
+ retrieveConsumer()
+ }
+ _consumer.get
+ }
- case Some(consumer) => consumer
+ private def retrieveConsumer(): Unit = {
+ _consumer = Option(consumerPool.borrowObject(cacheKey, kafkaParams))
+ require(_consumer.isDefined, "borrowing consumer from pool must always succeed.")
}
private def getOrRetrieveFetchedData(offset: Long): FetchedData = _fetchedData match {
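
getOrRetrieveConsumer now re-validates the pooled consumer's delegation token before every use. A hedged sketch of that check, assuming code in the same kafka010 package; recreate is a hypothetical stand-in for the release, pool-invalidate and re-borrow sequence in the real method:

    // If the cached params still carry the current token, reuse the consumer;
    // otherwise recreate it so the new credentials are picked up.
    def reuseOrRefresh(
        cached: InternalKafkaConsumer,
        recreate: () => InternalKafkaConsumer): InternalKafkaConsumer = {
      if (KafkaTokenUtil.isConnectorUsingCurrentToken(
          cached.kafkaParamsWithSecurity, cached.clusterConfig)) {
        cached
      } else {
        recreate()
      }
    }
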
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
index 884773452b2a5..3f8d3d2da5797 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
@@ -21,7 +21,7 @@ import java.{util => ju}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
/**
* Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
index 9cd16c8e16249..01f6ba4445162 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
@@ -26,10 +26,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
+import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream
-import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.UninterruptibleThread
@@ -64,6 +64,8 @@ private[kafka010] class KafkaMicroBatchStream(
private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)
+ private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
+
private val rangeCalculator = KafkaOffsetRangeCalculator(options)
private var endPartitionOffsets: KafkaSourceOffset = _
@@ -112,7 +114,7 @@ private[kafka010] class KafkaMicroBatchStream(
if (deletedPartitions.nonEmpty) {
val message =
if (kafkaOffsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
- s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
+ s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
@@ -146,7 +148,8 @@ private[kafka010] class KafkaMicroBatchStream(
// Generate factories based on the offset ranges
offsetRanges.map { range =>
- KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
+ KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs,
+ failOnDataLoss, includeHeaders)
}.toArray
}
@@ -189,6 +192,8 @@ private[kafka010] class KafkaMicroBatchStream(
KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets(None))
case SpecificOffsetRangeLimit(p) =>
kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss)
+ case SpecificTimestampRangeLimit(p) =>
+ kafkaOffsetReader.fetchSpecificTimestampBasedOffsets(p, failsOnNoMatchingOffset = true)
}
metadataLog.add(0, offsets)
logInfo(s"Initial offsets: $offsets")
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala
index 80a026f4f5d73..d64b5d4f7e9e8 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala
@@ -42,6 +42,13 @@ private[kafka010] case object LatestOffsetRangeLimit extends KafkaOffsetRangeLim
private[kafka010] case class SpecificOffsetRangeLimit(
partitionOffsets: Map[TopicPartition, Long]) extends KafkaOffsetRangeLimit
+/**
+ * Represents the desire to bind to the earliest offset whose timestamp is equal to or
+ * greater than the specified timestamp.
+ */
+private[kafka010] case class SpecificTimestampRangeLimit(
+ topicTimestamps: Map[TopicPartition, Long]) extends KafkaOffsetRangeLimit
+
private[kafka010] object KafkaOffsetRangeLimit {
/**
* Used to denote offset range limits that are resolved via Kafka
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
index f3effd5300a79..0179f4dd822f1 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
@@ -26,12 +26,11 @@ import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal
-import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer}
+import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetAndTimestamp}
import org.apache.kafka.common.TopicPartition
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
/**
@@ -127,12 +126,14 @@ private[kafka010] class KafkaOffsetReader(
* Fetch the partition offsets for the topic partitions that are indicated
* in the [[ConsumerStrategy]] and [[KafkaOffsetRangeLimit]].
*/
- def fetchPartitionOffsets(offsetRangeLimit: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = {
+ def fetchPartitionOffsets(
+ offsetRangeLimit: KafkaOffsetRangeLimit,
+ isStartingOffsets: Boolean): Map[TopicPartition, Long] = {
def validateTopicPartitions(partitions: Set[TopicPartition],
partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
assert(partitions == partitionOffsets.keySet,
"If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
- "Use -1 for latest, -2 for earliest, if you don't care.\n" +
+ "Use -1 for latest, -2 for earliest.\n" +
s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions}")
logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")
partitionOffsets
@@ -148,6 +149,9 @@ private[kafka010] class KafkaOffsetReader(
}.toMap
case SpecificOffsetRangeLimit(partitionOffsets) =>
validateTopicPartitions(partitions, partitionOffsets)
+ case SpecificTimestampRangeLimit(partitionTimestamps) =>
+ fetchSpecificTimestampBasedOffsets(partitionTimestamps,
+ failsOnNoMatchingOffset = isStartingOffsets).partitionToOffsets
}
}
@@ -162,23 +166,83 @@ private[kafka010] class KafkaOffsetReader(
def fetchSpecificOffsets(
partitionOffsets: Map[TopicPartition, Long],
reportDataLoss: String => Unit): KafkaSourceOffset = {
- val fetched = runUninterruptibly {
- withRetriesWithoutInterrupt {
- // Poll to get the latest assigned partitions
- consumer.poll(0)
- val partitions = consumer.assignment()
+ val fnAssertParametersWithPartitions: ju.Set[TopicPartition] => Unit = { partitions =>
+ assert(partitions.asScala == partitionOffsets.keySet,
+ "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
+ "Use -1 for latest, -2 for earliest, if you don't care.\n" +
+ s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}")
+ logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")
+ }
- // Call `position` to wait until the potential offset request triggered by `poll(0)` is
- // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by
- // `poll(0)` may reset offsets that should have been set by another request.
- partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {})
+ val fnRetrievePartitionOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long] = { _ =>
+ partitionOffsets
+ }
- consumer.pause(partitions)
- assert(partitions.asScala == partitionOffsets.keySet,
- "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
- "Use -1 for latest, -2 for earliest, if you don't care.\n" +
- s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}")
- logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")
+ val fnAssertFetchedOffsets: Map[TopicPartition, Long] => Unit = { fetched =>
+ partitionOffsets.foreach {
+ case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
+ off != KafkaOffsetRangeLimit.EARLIEST =>
+ if (fetched(tp) != off) {
+ reportDataLoss(
+ s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}")
+ }
+ case _ =>
+ // no real way to check that beginning or end is reasonable
+ }
+ }
+
+ fetchSpecificOffsets0(fnAssertParametersWithPartitions, fnRetrievePartitionOffsets,
+ fnAssertFetchedOffsets)
+ }
+
+ def fetchSpecificTimestampBasedOffsets(
+ partitionTimestamps: Map[TopicPartition, Long],
+ failsOnNoMatchingOffset: Boolean): KafkaSourceOffset = {
+ val fnAssertParametersWithPartitions: ju.Set[TopicPartition] => Unit = { partitions =>
+ assert(partitions.asScala == partitionTimestamps.keySet,
+ "If starting/endingOffsetsByTimestamp contains specific offsets, you must specify all " +
+ s"topics. Specified: ${partitionTimestamps.keySet} Assigned: ${partitions.asScala}")
+ logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionTimestamps")
+ }
+
+ val fnRetrievePartitionOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long] = { _ => {
+ val converted = partitionTimestamps.map { case (tp, timestamp) =>
+ tp -> java.lang.Long.valueOf(timestamp)
+ }.asJava
+
+ val offsetForTime: ju.Map[TopicPartition, OffsetAndTimestamp] =
+ consumer.offsetsForTimes(converted)
+
+ offsetForTime.asScala.map { case (tp, offsetAndTimestamp) =>
+ if (failsOnNoMatchingOffset) {
+ assert(offsetAndTimestamp != null, "No offset matched from request of " +
+ s"topic-partition $tp and timestamp ${partitionTimestamps(tp)}.")
+ }
+
+ if (offsetAndTimestamp == null) {
+ tp -> KafkaOffsetRangeLimit.LATEST
+ } else {
+ tp -> offsetAndTimestamp.offset()
+ }
+ }.toMap
+ }
+ }
+
+ val fnAssertFetchedOffsets: Map[TopicPartition, Long] => Unit = { _ => }
+
+ fetchSpecificOffsets0(fnAssertParametersWithPartitions, fnRetrievePartitionOffsets,
+ fnAssertFetchedOffsets)
+ }
+
+ private def fetchSpecificOffsets0(
+ fnAssertParametersWithPartitions: ju.Set[TopicPartition] => Unit,
+ fnRetrievePartitionOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long],
+ fnAssertFetchedOffsets: Map[TopicPartition, Long] => Unit): KafkaSourceOffset = {
+ val fetched = partitionsAssignedToConsumer {
+ partitions => {
+ fnAssertParametersWithPartitions(partitions)
+
+ val partitionOffsets = fnRetrievePartitionOffsets(partitions)
partitionOffsets.foreach {
case (tp, KafkaOffsetRangeLimit.LATEST) =>
@@ -187,22 +251,15 @@ private[kafka010] class KafkaOffsetReader(
consumer.seekToBeginning(ju.Arrays.asList(tp))
case (tp, off) => consumer.seek(tp, off)
}
+
partitionOffsets.map {
case (tp, _) => tp -> consumer.position(tp)
}
}
}
- partitionOffsets.foreach {
- case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
- off != KafkaOffsetRangeLimit.EARLIEST =>
- if (fetched(tp) != off) {
- reportDataLoss(
- s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}")
- }
- case _ =>
- // no real way to check that beginning or end is reasonable
- }
+ fnAssertFetchedOffsets(fetched)
+
KafkaSourceOffset(fetched)
}
@@ -210,20 +267,15 @@ private[kafka010] class KafkaOffsetReader(
* Fetch the earliest offsets for the topic partitions that are indicated
* in the [[ConsumerStrategy]].
*/
- def fetchEarliestOffsets(): Map[TopicPartition, Long] = runUninterruptibly {
- withRetriesWithoutInterrupt {
- // Poll to get the latest assigned partitions
- consumer.poll(0)
- val partitions = consumer.assignment()
- consumer.pause(partitions)
- logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning")
+ def fetchEarliestOffsets(): Map[TopicPartition, Long] = partitionsAssignedToConsumer(
+ partitions => {
+ logDebug("Seeking to the beginning")
consumer.seekToBeginning(partitions)
val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
logDebug(s"Got earliest offsets for partition : $partitionOffsets")
partitionOffsets
- }
- }
+ }, fetchingEarliestOffset = true)
/**
* Fetch the latest offsets for the topic partitions that are indicated
@@ -240,19 +292,9 @@ private[kafka010] class KafkaOffsetReader(
* distinguish this with KAFKA-7703, so we just return whatever we get from Kafka after retrying.
*/
def fetchLatestOffsets(
- knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap = runUninterruptibly {
- withRetriesWithoutInterrupt {
- // Poll to get the latest assigned partitions
- consumer.poll(0)
- val partitions = consumer.assignment()
-
- // Call `position` to wait until the potential offset request triggered by `poll(0)` is
- // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by
- // `poll(0)` may reset offsets that should have been set by another request.
- partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {})
-
- consumer.pause(partitions)
- logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.")
+ knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap =
+ partitionsAssignedToConsumer { partitions => {
+ logDebug("Seeking to the end.")
if (knownOffsets.isEmpty) {
consumer.seekToEnd(partitions)
@@ -316,25 +358,40 @@ private[kafka010] class KafkaOffsetReader(
if (newPartitions.isEmpty) {
Map.empty[TopicPartition, Long]
} else {
- runUninterruptibly {
- withRetriesWithoutInterrupt {
- // Poll to get the latest assigned partitions
- consumer.poll(0)
- val partitions = consumer.assignment()
- consumer.pause(partitions)
- logDebug(s"\tPartitions assigned to consumer: $partitions")
-
- // Get the earliest offset of each partition
- consumer.seekToBeginning(partitions)
- val partitionOffsets = newPartitions.filter { p =>
- // When deleting topics happen at the same time, some partitions may not be in
- // `partitions`. So we need to ignore them
- partitions.contains(p)
- }.map(p => p -> consumer.position(p)).toMap
- logDebug(s"Got earliest offsets for new partitions: $partitionOffsets")
- partitionOffsets
- }
+ partitionsAssignedToConsumer(partitions => {
+ // Get the earliest offset of each partition
+ consumer.seekToBeginning(partitions)
+ val partitionOffsets = newPartitions.filter { p =>
+ // When topics are deleted at the same time, some partitions may not be in
+ // `partitions`, so we need to ignore them.
+ partitions.contains(p)
+ }.map(p => p -> consumer.position(p)).toMap
+ logDebug(s"Got earliest offsets for new partitions: $partitionOffsets")
+ partitionOffsets
+ }, fetchingEarliestOffset = true)
+ }
+ }
+
+ private def partitionsAssignedToConsumer(
+ body: ju.Set[TopicPartition] => Map[TopicPartition, Long],
+ fetchingEarliestOffset: Boolean = false)
+ : Map[TopicPartition, Long] = runUninterruptibly {
+
+ withRetriesWithoutInterrupt {
+ // Poll to get the latest assigned partitions
+ consumer.poll(0)
+ val partitions = consumer.assignment()
+
+ if (!fetchingEarliestOffset) {
+ // Call `position` to wait until the potential offset request triggered by `poll(0)` is
+ // done. This is a workaround for KAFKA-7703, in which an async `seekToBeginning` triggered by
+ // `poll(0)` may reset offsets that should have been set by another request.
+ partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {})
}
+
+ consumer.pause(partitions)
+ logDebug(s"Partitions assigned to consumer: $partitions.")
+ body(partitions)
}
}
@@ -421,16 +478,3 @@ private[kafka010] class KafkaOffsetReader(
_consumer = null // will automatically get reinitialized again
}
}
-
-private[kafka010] object KafkaOffsetReader {
-
- def kafkaSchema: StructType = StructType(Seq(
- StructField("key", BinaryType),
- StructField("value", BinaryType),
- StructField("topic", StringType),
- StructField("partition", IntegerType),
- StructField("offset", LongType),
- StructField("timestamp", TimestampType),
- StructField("timestampType", IntegerType)
- ))
-}
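
The new timestamp path resolves each requested timestamp through the consumer's offsetsForTimes call; a partition with no matching record either fails the query (failsOnNoMatchingOffset = true) or falls back to the latest offset. A self-contained sketch of that resolution step, assuming a KafkaConsumer already assigned to the requested partitions (-1L stands in for KafkaOffsetRangeLimit.LATEST):

    import java.{util => ju}
    import scala.collection.JavaConverters._
    import org.apache.kafka.clients.consumer.KafkaConsumer
    import org.apache.kafka.common.TopicPartition

    def resolveByTimestamp(
        consumer: KafkaConsumer[Array[Byte], Array[Byte]],
        timestamps: Map[TopicPartition, Long],
        failsOnNoMatchingOffset: Boolean): Map[TopicPartition, Long] = {
      // offsetsForTimes expects boxed longs keyed by partition
      val request: ju.Map[TopicPartition, java.lang.Long] =
        timestamps.map { case (tp, ts) => tp -> java.lang.Long.valueOf(ts) }.asJava
      consumer.offsetsForTimes(request).asScala.map { case (tp, offsetAndTimestamp) =>
        if (offsetAndTimestamp == null) {
          require(!failsOnNoMatchingOffset,
            s"No offset matched timestamp ${timestamps(tp)} for $tp")
          tp -> -1L // no matching record: fall back to latest
        } else {
          tp -> offsetAndTimestamp.offset()
        }
      }.toMap
    }
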
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala
new file mode 100644
index 0000000000000..aed099c142bc3
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.sql.Timestamp
+
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.clients.consumer.ConsumerRecord
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/** A simple class for converting Kafka ConsumerRecord to InternalRow/UnsafeRow */
+private[kafka010] class KafkaRecordToRowConverter {
+ import KafkaRecordToRowConverter._
+
+ private val toUnsafeRowWithoutHeaders = UnsafeProjection.create(schemaWithoutHeaders)
+ private val toUnsafeRowWithHeaders = UnsafeProjection.create(schemaWithHeaders)
+
+ val toInternalRowWithoutHeaders: Record => InternalRow =
+ (cr: Record) => InternalRow(
+ cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset,
+ DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id
+ )
+
+ val toInternalRowWithHeaders: Record => InternalRow =
+ (cr: Record) => InternalRow(
+ cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset,
+ DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id,
+ if (cr.headers.iterator().hasNext) {
+ new GenericArrayData(cr.headers.iterator().asScala
+ .map(header =>
+ InternalRow(UTF8String.fromString(header.key()), header.value())
+ ).toArray)
+ } else {
+ null
+ }
+ )
+
+ def toUnsafeRowWithoutHeadersProjector: Record => UnsafeRow =
+ (cr: Record) => toUnsafeRowWithoutHeaders(toInternalRowWithoutHeaders(cr))
+
+ def toUnsafeRowWithHeadersProjector: Record => UnsafeRow =
+ (cr: Record) => toUnsafeRowWithHeaders(toInternalRowWithHeaders(cr))
+
+ def toUnsafeRowProjector(includeHeaders: Boolean): Record => UnsafeRow = {
+ if (includeHeaders) toUnsafeRowWithHeadersProjector else toUnsafeRowWithoutHeadersProjector
+ }
+}
+
+private[kafka010] object KafkaRecordToRowConverter {
+ type Record = ConsumerRecord[Array[Byte], Array[Byte]]
+
+ val headersType = ArrayType(StructType(Array(
+ StructField("key", StringType),
+ StructField("value", BinaryType))))
+
+ private val schemaWithoutHeaders = new StructType(Array(
+ StructField("key", BinaryType),
+ StructField("value", BinaryType),
+ StructField("topic", StringType),
+ StructField("partition", IntegerType),
+ StructField("offset", LongType),
+ StructField("timestamp", TimestampType),
+ StructField("timestampType", IntegerType)
+ ))
+
+ private val schemaWithHeaders =
+ new StructType(schemaWithoutHeaders.fields :+ StructField("headers", headersType))
+
+ def kafkaSchema(includeHeaders: Boolean): StructType = {
+ if (includeHeaders) schemaWithHeaders else schemaWithoutHeaders
+ }
+}
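
A minimal usage sketch of the new converter as a partition reader would use it (assumes code in the same kafka010 package; the ConsumerRecord is constructed purely for illustration):

    import org.apache.kafka.clients.consumer.ConsumerRecord

    val converter = new KafkaRecordToRowConverter()
    val project = converter.toUnsafeRowProjector(includeHeaders = true)

    // This constructor leaves the timestamp unset and attaches no headers.
    val record = new ConsumerRecord[Array[Byte], Array[Byte]](
      "topicA", 0, 42L, "k".getBytes, "v".getBytes)

    val row = project(record)
    // 8 fields: key, value, topic, partition, offset, timestamp, timestampType, headers;
    // the headers column is null because the record carries none.
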
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
deleted file mode 100644
index 306ef10b775a9..0000000000000
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.kafka010
-
-import org.apache.kafka.clients.consumer.ConsumerRecord
-
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.unsafe.types.UTF8String
-
-/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */
-private[kafka010] class KafkaRecordToUnsafeRowConverter {
- private val rowWriter = new UnsafeRowWriter(7)
-
- def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = {
- rowWriter.reset()
- rowWriter.zeroOutNullBytes()
-
- if (record.key == null) {
- rowWriter.setNullAt(0)
- } else {
- rowWriter.write(0, record.key)
- }
- if (record.value == null) {
- rowWriter.setNullAt(1)
- } else {
- rowWriter.write(1, record.value)
- }
- rowWriter.write(2, UTF8String.fromString(record.topic))
- rowWriter.write(3, record.partition)
- rowWriter.write(4, record.offset)
- rowWriter.write(
- 5,
- DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp)))
- rowWriter.write(6, record.timestampType.id)
- rowWriter.getRow()
- }
-}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
index dc7087821b10c..61479c992039b 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
@@ -24,10 +24,9 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.types.UTF8String
private[kafka010] class KafkaRelation(
@@ -36,6 +35,7 @@ private[kafka010] class KafkaRelation(
sourceOptions: CaseInsensitiveMap[String],
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
+ includeHeaders: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
endingOffsets: KafkaOffsetRangeLimit)
extends BaseRelation with TableScan with Logging {
@@ -49,7 +49,9 @@ private[kafka010] class KafkaRelation(
(sqlContext.sparkContext.conf.get(NETWORK_TIMEOUT) * 1000L).toString
).toLong
- override def schema: StructType = KafkaOffsetReader.kafkaSchema
+ private val converter = new KafkaRecordToRowConverter()
+
+ override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
override def buildScan(): RDD[Row] = {
// Each running query should use its own group id. Otherwise, the query may be only assigned
@@ -66,8 +68,8 @@ private[kafka010] class KafkaRelation(
// Leverage the KafkaReader to obtain the relevant partition offsets
val (fromPartitionOffsets, untilPartitionOffsets) = {
try {
- (kafkaOffsetReader.fetchPartitionOffsets(startingOffsets),
- kafkaOffsetReader.fetchPartitionOffsets(endingOffsets))
+ (kafkaOffsetReader.fetchPartitionOffsets(startingOffsets, isStartingOffsets = true),
+ kafkaOffsetReader.fetchPartitionOffsets(endingOffsets, isStartingOffsets = false))
} finally {
kafkaOffsetReader.close()
}
@@ -100,18 +102,14 @@ private[kafka010] class KafkaRelation(
// Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
val executorKafkaParams =
KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
+ val toInternalRow = if (includeHeaders) {
+ converter.toInternalRowWithHeaders
+ } else {
+ converter.toInternalRowWithoutHeaders
+ }
val rdd = new KafkaSourceRDD(
sqlContext.sparkContext, executorKafkaParams, offsetRanges,
- pollTimeoutMs, failOnDataLoss).map { cr =>
- InternalRow(
- cr.key,
- cr.value,
- UTF8String.fromString(cr.topic),
- cr.partition,
- cr.offset,
- DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
- cr.timestampType.id)
- }
+ pollTimeoutMs, failOnDataLoss).map(toInternalRow)
sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema).rdd
}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index d1a35ec53bc94..e1392b6215d3a 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -31,12 +31,11 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.kafka010.KafkaSource._
-import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
/**
* A [[Source]] that reads data from Kafka using the following design.
@@ -84,13 +83,15 @@ private[kafka010] class KafkaSource(
private val sc = sqlContext.sparkContext
- private val pollTimeoutMs = sourceOptions.getOrElse(
- KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
- (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString
- ).toLong
+ private val pollTimeoutMs =
+ sourceOptions.getOrElse(CONSUMER_POLL_TIMEOUT, (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString)
+ .toLong
private val maxOffsetsPerTrigger =
- sourceOptions.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER).map(_.toLong)
+ sourceOptions.get(MAX_OFFSET_PER_TRIGGER).map(_.toLong)
+
+ private val includeHeaders =
+ sourceOptions.getOrElse(INCLUDE_HEADERS, "false").toBoolean
/**
* Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
@@ -104,6 +105,8 @@ private[kafka010] class KafkaSource(
case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets())
case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets(None))
case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss)
+ case SpecificTimestampRangeLimit(p) =>
+ kafkaReader.fetchSpecificTimestampBasedOffsets(p, failsOnNoMatchingOffset = true)
}
metadataLog.add(0, offsets)
logInfo(s"Initial offsets: $offsets")
@@ -113,7 +116,9 @@ private[kafka010] class KafkaSource(
private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None
- override def schema: StructType = KafkaOffsetReader.kafkaSchema
+ private val converter = new KafkaRecordToRowConverter()
+
+ override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
/** Returns the maximum available offset for this source. */
override def getOffset: Option[Offset] = {
@@ -223,7 +228,7 @@ private[kafka010] class KafkaSource(
val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet)
if (deletedPartitions.nonEmpty) {
val message = if (kafkaReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
- s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
+ s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
@@ -267,16 +272,14 @@ private[kafka010] class KafkaSource(
}.toArray
// Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
- val rdd = new KafkaSourceRDD(
- sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr =>
- InternalRow(
- cr.key,
- cr.value,
- UTF8String.fromString(cr.topic),
- cr.partition,
- cr.offset,
- DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
- cr.timestampType.id)
+ val rdd = if (includeHeaders) {
+ new KafkaSourceRDD(
+ sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss)
+ .map(converter.toInternalRowWithHeaders)
+ } else {
+ new KafkaSourceRDD(
+ sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss)
+ .map(converter.toInternalRowWithoutHeaders)
}
logInfo("GetBatch generating RDD of offset range: " +
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
index 90d70439c5329..b9674a30aee39 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.kafka010
import org.apache.kafka.common.TopicPartition
+import org.apache.spark.sql.connector.read.streaming.PartitionOffset
import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset}
-import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset
/**
* An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index c3f0be4be96e2..c15f08d78741d 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -30,14 +30,13 @@ import org.apache.spark.internal.Logging
import org.apache.spark.kafka010.KafkaConfigUpdater
import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.read.{Batch, Scan, ScanBuilder}
+import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
+import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder}
+import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.TableCapability._
-import org.apache.spark.sql.sources.v2.reader.{Batch, Scan, ScanBuilder}
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}
-import org.apache.spark.sql.sources.v2.writer.{BatchWrite, WriteBuilder}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -70,7 +69,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
validateStreamOptions(caseInsensitiveParameters)
require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one")
- (shortName(), KafkaOffsetReader.kafkaSchema)
+ val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean
+ (shortName(), KafkaRecordToRowConverter.kafkaSchema(includeHeaders))
}
override def createSource(
@@ -89,7 +89,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters)
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
- caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+ caseInsensitiveParameters, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY,
+ STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveParameters),
@@ -108,7 +109,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
- new KafkaTable
+ val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
+ new KafkaTable(includeHeaders)
}
/**
@@ -125,19 +127,24 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveParameters)
val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
- caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
+ caseInsensitiveParameters, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY,
+ STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
assert(startingRelationOffsets != LatestOffsetRangeLimit)
val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
- caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+ caseInsensitiveParameters, ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY,
+ ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
assert(endingRelationOffsets != EarliestOffsetRangeLimit)
+ val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean
+
new KafkaRelation(
sqlContext,
strategy(caseInsensitiveParameters),
sourceOptions = caseInsensitiveParameters,
specifiedKafkaParams = specifiedKafkaParams,
failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
+ includeHeaders = includeHeaders,
startingOffsets = startingRelationOffsets,
endingOffsets = endingRelationOffsets)
}
@@ -317,13 +324,17 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
// Stream specific options
params.get(ENDING_OFFSETS_OPTION_KEY).map(_ =>
throw new IllegalArgumentException("ending offset not valid in streaming queries"))
+ params.get(ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY).map(_ =>
+ throw new IllegalArgumentException("ending timestamp not valid in streaming queries"))
+
validateGeneralOptions(params)
}
private def validateBatchOptions(params: CaseInsensitiveMap[String]) = {
// Batch specific options
KafkaSourceProvider.getKafkaOffsetRangeLimit(
- params, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match {
+ params, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, STARTING_OFFSETS_OPTION_KEY,
+ EarliestOffsetRangeLimit) match {
case EarliestOffsetRangeLimit => // good to go
case LatestOffsetRangeLimit =>
throw new IllegalArgumentException("starting offset can't be latest " +
@@ -335,10 +346,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
"be latest for batch queries on Kafka")
case _ => // ignore
}
+ case _: SpecificTimestampRangeLimit => // good to go
}
KafkaSourceProvider.getKafkaOffsetRangeLimit(
- params, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match {
+ params, ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY, ENDING_OFFSETS_OPTION_KEY,
+ LatestOffsetRangeLimit) match {
case EarliestOffsetRangeLimit =>
throw new IllegalArgumentException("ending offset can't be earliest " +
"for batch queries on Kafka")
@@ -350,6 +363,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
"earliest for batch queries on Kafka")
case _ => // ignore
}
+ case _: SpecificTimestampRangeLimit => // good to go
}
validateGeneralOptions(params)
@@ -360,13 +374,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
}
- class KafkaTable extends Table with SupportsRead with SupportsWrite {
+ class KafkaTable(includeHeaders: Boolean) extends Table with SupportsRead with SupportsWrite {
override def name(): String = "KafkaTable"
- override def schema(): StructType = KafkaOffsetReader.kafkaSchema
+ override def schema(): StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
override def capabilities(): ju.Set[TableCapability] = {
+ import TableCapability._
// ACCEPT_ANY_SCHEMA is needed because of the following reasons:
// * Kafka writer validates the schema instead of the SQL analyzer (the schema is fixed)
// * Read schema differs from write schema (please see Kafka integration guide)
@@ -403,8 +418,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
class KafkaScan(options: CaseInsensitiveStringMap) extends Scan {
+ val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
- override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema
+ override def readSchema(): StructType = {
+ KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
+ }
override def toBatch(): Batch = {
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
@@ -412,10 +430,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
- caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
+ caseInsensitiveOptions, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY,
+ STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
- caseInsensitiveOptions, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+ caseInsensitiveOptions, ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY,
+ ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
new KafkaBatch(
strategy(caseInsensitiveOptions),
@@ -423,7 +443,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
specifiedKafkaParams,
failOnDataLoss(caseInsensitiveOptions),
startingRelationOffsets,
- endingRelationOffsets)
+ endingRelationOffsets,
+ includeHeaders)
}
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
@@ -437,7 +458,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
- caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+ caseInsensitiveOptions, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY,
+ STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveOptions),
@@ -465,7 +487,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
- caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+ caseInsensitiveOptions, STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY,
+ STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveOptions),
@@ -491,6 +514,8 @@ private[kafka010] object KafkaSourceProvider extends Logging {
private val STRATEGY_OPTION_KEYS = Set(SUBSCRIBE, SUBSCRIBE_PATTERN, ASSIGN)
private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
+ private[kafka010] val STARTING_OFFSETS_BY_TIMESTAMP_OPTION_KEY = "startingoffsetsbytimestamp"
+ private[kafka010] val ENDING_OFFSETS_BY_TIMESTAMP_OPTION_KEY = "endingoffsetsbytimestamp"
private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
private[kafka010] val MIN_PARTITIONS_OPTION_KEY = "minpartitions"
private[kafka010] val MAX_OFFSET_PER_TRIGGER = "maxoffsetspertrigger"
@@ -498,6 +523,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchoffset.retryintervalms"
private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaconsumer.polltimeoutms"
private val GROUP_ID_PREFIX = "groupidprefix"
+ private[kafka010] val INCLUDE_HEADERS = "includeheaders"
val TOPIC_OPTION_KEY = "topic"
@@ -533,15 +559,20 @@ private[kafka010] object KafkaSourceProvider extends Logging {
def getKafkaOffsetRangeLimit(
params: CaseInsensitiveMap[String],
+ offsetByTimestampOptionKey: String,
offsetOptionKey: String,
defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = {
- params.get(offsetOptionKey).map(_.trim) match {
- case Some(offset) if offset.toLowerCase(Locale.ROOT) == "latest" =>
- LatestOffsetRangeLimit
- case Some(offset) if offset.toLowerCase(Locale.ROOT) == "earliest" =>
- EarliestOffsetRangeLimit
- case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))
- case None => defaultOffsets
+ params.get(offsetByTimestampOptionKey).map(_.trim) match {
+ case Some(json) => SpecificTimestampRangeLimit(JsonUtils.partitionTimestamps(json))
+ case None =>
+ params.get(offsetOptionKey).map(_.trim) match {
+ case Some(offset) if offset.toLowerCase(Locale.ROOT) == "latest" =>
+ LatestOffsetRangeLimit
+ case Some(offset) if offset.toLowerCase(Locale.ROOT) == "earliest" =>
+ EarliestOffsetRangeLimit
+ case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))
+ case None => defaultOffsets
+ }
}
}
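
For orientation, a minimal sketch of how the new option is expected to surface in a streaming read (broker address, topic name and the epoch-millisecond values are placeholders, and an active SparkSession named spark is assumed). When a *ByTimestamp key is present it takes precedence over its plain offset counterpart, per the fallback above; the ending variant is rejected for streaming queries and is only meaningful for batch reads.

// Hypothetical streaming read starting from per-partition timestamps (epoch millis).
val stream = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("subscribe", "events")
  // JSON of topic -> partition -> timestamp, resolved to offsets by the provider
  .option("startingOffsetsByTimestamp", """{"events": {"0": 1565000000000, "1": 1565000000000}}""")
  .load()
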
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
index 6dd1d2984a96e..2b50b771e694e 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql.kafka010
import java.{util => ju}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
-import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.types.StructType
/**
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index 041fac7717635..b423ddc959c1b 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -19,9 +19,13 @@ package org.apache.spark.sql.kafka010
import java.{util => ju}
+import scala.collection.JavaConverters._
+
import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
+import org.apache.kafka.common.header.Header
+import org.apache.kafka.common.header.internals.RecordHeader
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
import org.apache.spark.sql.types.{BinaryType, StringType}
@@ -88,7 +92,17 @@ private[kafka010] abstract class KafkaRowWriter(
throw new NullPointerException(s"null topic present in the data. Use the " +
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
}
- val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
+ val record = if (projectedRow.isNullAt(3)) {
+ new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value)
+ } else {
+ val headerArray = projectedRow.getArray(3)
+ val headers = (0 until headerArray.numElements()).map { i =>
+ val struct = headerArray.getStruct(i, 2)
+ new RecordHeader(struct.getUTF8String(0).toString, struct.getBinary(1))
+ .asInstanceOf[Header]
+ }
+ new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value, headers.asJava)
+ }
producer.send(record, callback)
}
@@ -131,9 +145,26 @@ private[kafka010] abstract class KafkaRowWriter(
throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " +
s"attribute unsupported type ${t.catalogString}")
}
+ val headersExpression = inputSchema
+ .find(_.name == KafkaWriter.HEADERS_ATTRIBUTE_NAME).getOrElse(
+ Literal(CatalystTypeConverters.convertToCatalyst(null),
+ KafkaRecordToRowConverter.headersType)
+ )
+ headersExpression.dataType match {
+ case KafkaRecordToRowConverter.headersType => // good
+ case t =>
+ throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " +
+ s"attribute unsupported type ${t.catalogString}")
+ }
UnsafeProjection.create(
- Seq(topicExpression, Cast(keyExpression, BinaryType),
- Cast(valueExpression, BinaryType)), inputSchema)
+ Seq(
+ topicExpression,
+ Cast(keyExpression, BinaryType),
+ Cast(valueExpression, BinaryType),
+ headersExpression
+ ),
+ inputSchema
+ )
}
}
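
A hedged sketch of the matching user-facing write path (an active SparkSession named spark is assumed; broker, topic, payload and header bytes are made up). The optional headers column must carry exactly the array<struct<key: string, value: binary>> shape that the projection above validates, so the schema is spelled out explicitly rather than inferred.

import java.nio.charset.StandardCharsets.UTF_8

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ArrayType, BinaryType, StringType, StructField, StructType}

// array<struct<key: string, value: binary>>, mirroring the headers type checked above
val headersType = ArrayType(StructType(Seq(
  StructField("key", StringType), StructField("value", BinaryType))))

val rows = Seq(Row("events", "payload-1", Seq(Row("trace-id", "abc".getBytes(UTF_8)))))
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(rows),
  StructType(Seq(
    StructField("topic", StringType),
    StructField("value", StringType),
    StructField("headers", headersType))))

df.write
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .mode("append")
  .save()
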
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index e1a9191cc5a84..bbb060356f730 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -21,9 +21,10 @@ import java.{util => ju}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
-import org.apache.spark.sql.types.{BinaryType, StringType}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.types.{BinaryType, MapType, StringType}
import org.apache.spark.util.Utils
/**
@@ -39,6 +40,7 @@ private[kafka010] object KafkaWriter extends Logging {
val TOPIC_ATTRIBUTE_NAME: String = "topic"
val KEY_ATTRIBUTE_NAME: String = "key"
val VALUE_ATTRIBUTE_NAME: String = "value"
+ val HEADERS_ATTRIBUTE_NAME: String = "headers"
override def toString: String = "KafkaWriter"
@@ -75,6 +77,15 @@ private[kafka010] object KafkaWriter extends Logging {
throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " +
s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}")
}
+ schema.find(_.name == HEADERS_ATTRIBUTE_NAME).getOrElse(
+ Literal(CatalystTypeConverters.convertToCatalyst(null),
+ KafkaRecordToRowConverter.headersType)
+ ).dataType match {
+ case KafkaRecordToRowConverter.headersType => // good
+ case _ =>
+ throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " +
+ s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}")
+ }
}
def write(
diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/commits/0 b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/commits/0
new file mode 100644
index 0000000000000..9c1e3021c3ead
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/commits/0
@@ -0,0 +1,2 @@
+v1
+{"nextBatchWatermarkMs":0}
\ No newline at end of file
diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/metadata b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/metadata
new file mode 100644
index 0000000000000..f1b5ab7aa17f0
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/metadata
@@ -0,0 +1 @@
+{"id":"fc415a71-f0a2-4c3c-aeaf-f9e258c3f726"}
\ No newline at end of file
diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/offsets/0 b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/offsets/0
new file mode 100644
index 0000000000000..5dbadea57acbe
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/offsets/0
@@ -0,0 +1,3 @@
+v1
+{"batchWatermarkMs":0,"batchTimestampMs":1568508285207,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5"}}
+{"spark-test-topic-2b8619f5-d3c4-4c2d-b5d1-8d9d9458aa62":{"2":3,"4":3,"1":3,"3":3,"0":3}}
\ No newline at end of file
diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/sources/0/0 b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/sources/0/0
new file mode 100644
index 0000000000000..8cf9f8e009ce8
Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/sources/0/0 differ
diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/0/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/0/1.delta
new file mode 100644
index 0000000000000..5815bbdcc2467
Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/0/1.delta differ
diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/1/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/1/1.delta
new file mode 100644
index 0000000000000..e1a065b2b1c78
Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/1/1.delta differ
diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/2/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/2/1.delta
new file mode 100644
index 0000000000000..cce14294e0044
Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/2/1.delta differ
diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/3/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/3/1.delta
new file mode 100644
index 0000000000000..57063019503bc
Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/3/1.delta differ
diff --git a/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/4/1.delta b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/4/1.delta
new file mode 100644
index 0000000000000..e8b1e4bdc8dba
Binary files /dev/null and b/external/kafka-0-10-sql/src/test/resources/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/state/0/4/1.delta differ
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
index 80f9a1b410d2c..122fe752615ad 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.kafka010
+import java.{util => ju}
+import java.nio.charset.StandardCharsets
import java.util.concurrent.{Executors, TimeUnit}
import scala.collection.JavaConverters._
@@ -29,10 +31,14 @@ import org.apache.kafka.common.serialization.ByteArrayDeserializer
import org.scalatest.PrivateMethodTester
import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.kafka010.KafkaDelegationTokenTest
import org.apache.spark.sql.kafka010.KafkaDataConsumer.CacheKey
import org.apache.spark.sql.test.SharedSparkSession
-class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester {
+class KafkaDataConsumerSuite
+ extends SharedSparkSession
+ with PrivateMethodTester
+ with KafkaDelegationTokenTest {
protected var testUtils: KafkaTestUtils = _
private val topic = "topic" + Random.nextInt()
@@ -65,6 +71,8 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
private var consumerPool: InternalKafkaConsumerPool = _
override def beforeEach(): Unit = {
+ super.beforeEach()
+
fetchedDataPool = {
val fetchedDataPoolMethod = PrivateMethod[FetchedDataPool]('fetchedDataPool)
KafkaDataConsumer.invokePrivate(fetchedDataPoolMethod())
@@ -91,53 +99,93 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
test("new KafkaDataConsumer instance in case of Task retry") {
try {
val kafkaParams = getKafkaParams()
- val key = new CacheKey(groupId, topicPartition)
+ val key = CacheKey(groupId, topicPartition)
val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
TaskContext.setTaskContext(context1)
- val consumer1 = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
-
- // any method call which requires consumer is necessary
- consumer1.getAvailableOffsetRange()
-
- val consumer1Underlying = consumer1._consumer
- assert(consumer1Underlying.isDefined)
-
- consumer1.release()
-
- assert(consumerPool.size(key) === 1)
- // check whether acquired object is available in pool
- val pooledObj = consumerPool.borrowObject(key, kafkaParams)
- assert(consumer1Underlying.get.eq(pooledObj))
- consumerPool.returnObject(pooledObj)
+ val consumer1Underlying = initSingleConsumer(kafkaParams, key)
val context2 = new TaskContextImpl(0, 0, 0, 0, 1, null, null, null)
TaskContext.setTaskContext(context2)
- val consumer2 = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
-
- // any method call which requires consumer is necessary
- consumer2.getAvailableOffsetRange()
+ val consumer2Underlying = initSingleConsumer(kafkaParams, key)
- val consumer2Underlying = consumer2._consumer
- assert(consumer2Underlying.isDefined)
// here we expect different consumer as pool will invalidate for task reattempt
- assert(consumer2Underlying.get.ne(consumer1Underlying.get))
+ assert(consumer2Underlying.ne(consumer1Underlying))
+ } finally {
+ TaskContext.unset()
+ }
+ }
- consumer2.release()
+ test("same KafkaDataConsumer instance in case of same token") {
+ try {
+ val kafkaParams = getKafkaParams()
+ val key = new CacheKey(groupId, topicPartition)
+
+ val context = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
+ TaskContext.setTaskContext(context)
+ setSparkEnv(
+ Map(
+ s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers
+ )
+ )
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
+ val consumer1Underlying = initSingleConsumer(kafkaParams, key)
+ val consumer2Underlying = initSingleConsumer(kafkaParams, key)
+
+ assert(consumer2Underlying.eq(consumer1Underlying))
+ } finally {
+ TaskContext.unset()
+ }
+ }
- // The first consumer should be removed from cache, but the consumer after invalidate
- // should be cached.
- assert(consumerPool.size(key) === 1)
- val pooledObj2 = consumerPool.borrowObject(key, kafkaParams)
- assert(consumer2Underlying.get.eq(pooledObj2))
- consumerPool.returnObject(pooledObj2)
+ test("new KafkaDataConsumer instance in case of token renewal") {
+ try {
+ val kafkaParams = getKafkaParams()
+ val key = new CacheKey(groupId, topicPartition)
+
+ val context = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
+ TaskContext.setTaskContext(context)
+ setSparkEnv(
+ Map(
+ s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers
+ )
+ )
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
+ val consumer1Underlying = initSingleConsumer(kafkaParams, key)
+ addTokenToUGI(tokenService1, tokenId2, tokenPassword2)
+ val consumer2Underlying = initSingleConsumer(kafkaParams, key)
+
+ assert(consumer2Underlying.ne(consumer1Underlying))
} finally {
TaskContext.unset()
}
}
+ private def initSingleConsumer(
+ kafkaParams: ju.Map[String, Object],
+ key: CacheKey): InternalKafkaConsumer = {
+ val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
+
+ // any method call which requires the consumer is necessary
+ consumer.getOrRetrieveConsumer()
+
+ val consumerUnderlying = consumer._consumer
+ assert(consumerUnderlying.isDefined)
+
+ consumer.release()
+
+ assert(consumerPool.size(key) === 1)
+ // check whether acquired object is available in pool
+ val pooledObj = consumerPool.borrowObject(key, kafkaParams)
+ assert(consumerUnderlying.get.eq(pooledObj))
+ consumerPool.returnObject(pooledObj)
+
+ consumerUnderlying.get
+ }
+
test("SPARK-23623: concurrent use of KafkaDataConsumer") {
- val data: immutable.IndexedSeq[String] = prepareTestTopicHavingTestMessages(topic)
+ val data: immutable.IndexedSeq[(String, Seq[(String, Array[Byte])])] =
+ prepareTestTopicHavingTestMessages(topic)
val topicPartition = new TopicPartition(topic, 0)
val kafkaParams = getKafkaParams()
@@ -157,10 +205,22 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
try {
val range = consumer.getAvailableOffsetRange()
val rcvd = range.earliest until range.latest map { offset =>
- val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value()
- new String(bytes)
+ val record = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false)
+ val value = new String(record.value(), StandardCharsets.UTF_8)
+ val headers = record.headers().toArray.map(header => (header.key(), header.value())).toSeq
+ (value, headers)
+ }
+ data.zip(rcvd).foreach { case (expected, actual) =>
+ // value
+ assert(expected._1 === actual._1)
+ // headers
+ expected._2.zip(actual._2).foreach { case (l, r) =>
+ // header key
+ assert(l._1 === r._1)
+ // header value
+ assert(l._2 === r._2)
+ }
}
- assert(rcvd == data)
} catch {
case e: Throwable =>
error = e
@@ -307,9 +367,12 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
}
private def prepareTestTopicHavingTestMessages(topic: String) = {
- val data = (1 to 1000).map(_.toString)
+ val data = (1 to 1000).map(i => (i.toString, Seq[(String, Array[Byte])]()))
testUtils.createTopic(topic, 1)
- testUtils.sendMessages(topic, data.toArray)
+ val messages = data.map { case (value, hdrs) =>
+ new RecordBuilder(topic, value).headers(hdrs).build()
+ }
+ testUtils.sendMessages(messages)
data
}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenSuite.scala
index 9850a91f34f63..306483825ae3b 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenSuite.scala
@@ -82,7 +82,6 @@ class KafkaDelegationTokenSuite extends StreamTest with SharedSparkSession with
.format("kafka")
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
- .option("kafka.security.protocol", SASL_PLAINTEXT.name)
.option("topic", topic)
.start()
@@ -99,7 +98,6 @@ class KafkaDelegationTokenSuite extends StreamTest with SharedSparkSession with
val streamingDf = spark.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
- .option("kafka.security.protocol", SASL_PLAINTEXT.name)
.option("startingOffsets", s"earliest")
.option("subscribe", topic)
.load()
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index ae8a6886b2b4d..3ee59e57a6edf 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -28,13 +28,15 @@ import scala.collection.JavaConverters._
import scala.io.Source
import scala.util.Random
+import org.apache.commons.io.FileUtils
import org.apache.kafka.clients.producer.{ProducerRecord, RecordMetadata}
import org.apache.kafka.common.TopicPartition
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.SpanSugar._
-import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
+import org.apache.spark.sql.{Dataset, ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.streaming._
@@ -42,11 +44,11 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
-import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.Utils
abstract class KafkaSourceTest extends StreamTest with SharedSparkSession with KafkaTest {
@@ -677,7 +679,8 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
})
}
- private def testGroupId(groupIdKey: String, validateGroupId: (String, Iterable[String]) => Unit) {
+ private def testGroupId(groupIdKey: String,
+ validateGroupId: (String, Iterable[String]) => Unit): Unit = {
// Tests code path KafkaSourceProvider.{sourceSchema(.), createSource(.)}
// as well as KafkaOffsetReader.createConsumer(.)
val topic = newTopic()
@@ -1162,6 +1165,63 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) }
}
+ test("default config of includeHeader doesn't break existing query from Spark 2.4") {
+ import testImplicits._
+
+ // This topic name is migrated from Spark 2.4.3 test run
+ val topic = "spark-test-topic-2b8619f5-d3c4-4c2d-b5d1-8d9d9458aa62"
+ // create same topic and messages as test run
+ testUtils.createTopic(topic, partitions = 5, overwrite = true)
+ testUtils.sendMessages(topic, Array(-20, -21, -22).map(_.toString), Some(0))
+ testUtils.sendMessages(topic, Array(-10, -11, -12).map(_.toString), Some(1))
+ testUtils.sendMessages(topic, Array(0, 1, 2).map(_.toString), Some(2))
+ testUtils.sendMessages(topic, Array(10, 11, 12).map(_.toString), Some(3))
+ testUtils.sendMessages(topic, Array(20, 21, 22).map(_.toString), Some(4))
+ require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+ val headers = Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))
+ (31 to 35).map { num =>
+ new RecordBuilder(topic, num.toString).partition(num - 31).headers(headers).build()
+ }.foreach { rec => testUtils.sendMessage(rec) }
+
+ val kafka = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("subscribePattern", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+
+ val query = kafka.dropDuplicates()
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+ .map(kv => kv._2.toInt + 1)
+
+ val resourceUri = this.getClass.getResource(
+ "/structured-streaming/checkpoint-version-2.4.3-kafka-include-headers-default/").toURI
+
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ // Copy the checkpoint to a temp dir to prevent changes to the original.
+ // Not doing this will lead to the test passing on the first run but failing on subsequent runs.
+ FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+ testStream(query)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ /*
+ Note: The checkpoint was generated using the following input in Spark version 2.4.3
+ testUtils.createTopic(topic, partitions = 5, overwrite = true)
+
+ testUtils.sendMessages(topic, Array(-20, -21, -22).map(_.toString), Some(0))
+ testUtils.sendMessages(topic, Array(-10, -11, -12).map(_.toString), Some(1))
+ testUtils.sendMessages(topic, Array(0, 1, 2).map(_.toString), Some(2))
+ testUtils.sendMessages(topic, Array(10, 11, 12).map(_.toString), Some(3))
+ testUtils.sendMessages(topic, Array(20, 21, 22).map(_.toString), Some(4))
+ */
+ makeSureGetOffsetCalled,
+ CheckNewAnswer(32, 33, 34, 35, 36)
+ )
+ }
}
abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
@@ -1219,6 +1279,16 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
"failOnDataLoss" -> failOnDataLoss.toString)
}
+ test(s"assign from specific timestamps (failOnDataLoss: $failOnDataLoss)") {
+ val topic = newTopic()
+ testFromSpecificTimestamps(
+ topic,
+ failOnDataLoss = failOnDataLoss,
+ addPartitions = false,
+ "assign" -> assignString(topic, 0 to 4),
+ "failOnDataLoss" -> failOnDataLoss.toString)
+ }
+
test(s"subscribing topic by name from latest offsets (failOnDataLoss: $failOnDataLoss)") {
val topic = newTopic()
testFromLatestOffsets(
@@ -1242,6 +1312,12 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
testFromSpecificOffsets(topic, failOnDataLoss = failOnDataLoss, "subscribe" -> topic)
}
+ test(s"subscribing topic by name from specific timestamps (failOnDataLoss: $failOnDataLoss)") {
+ val topic = newTopic()
+ testFromSpecificTimestamps(topic, failOnDataLoss = failOnDataLoss, addPartitions = true,
+ "subscribe" -> topic)
+ }
+
test(s"subscribing topic by pattern from latest offsets (failOnDataLoss: $failOnDataLoss)") {
val topicPrefix = newTopic()
val topic = topicPrefix + "-suffix"
@@ -1270,6 +1346,17 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
failOnDataLoss = failOnDataLoss,
"subscribePattern" -> s"$topicPrefix-.*")
}
+
+ test(s"subscribing topic by pattern from specific timestamps " +
+ s"(failOnDataLoss: $failOnDataLoss)") {
+ val topicPrefix = newTopic()
+ val topic = topicPrefix + "-suffix"
+ testFromSpecificTimestamps(
+ topic,
+ failOnDataLoss = failOnDataLoss,
+ addPartitions = true,
+ "subscribePattern" -> s"$topicPrefix-.*")
+ }
}
test("bad source options") {
@@ -1289,6 +1376,9 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
// Specifying an ending offset
testBadOptions("endingOffsets" -> "latest")("Ending offset not valid in streaming queries")
+ testBadOptions("subscribe" -> "t", "endingOffsetsByTimestamp" -> "{\"t\": {\"0\": 1000}}")(
+ "Ending timestamp not valid in streaming queries")
+
// No strategy specified
testBadOptions()("options must be specified", "subscribe", "subscribePattern")
@@ -1337,7 +1427,8 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
(STARTING_OFFSETS_OPTION_KEY, """{"topic-A":{"0":23}}""",
SpecificOffsetRangeLimit(Map(new TopicPartition("topic-A", 0) -> 23))))) {
val offset = getKafkaOffsetRangeLimit(
- CaseInsensitiveMap[String](Map(optionKey -> optionValue)), optionKey, answer)
+ CaseInsensitiveMap[String](Map(optionKey -> optionValue)), "dummy", optionKey,
+ answer)
assert(offset === answer)
}
@@ -1345,7 +1436,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
(STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit),
(ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit))) {
val offset = getKafkaOffsetRangeLimit(
- CaseInsensitiveMap[String](Map.empty), optionKey, answer)
+ CaseInsensitiveMap[String](Map.empty), "dummy", optionKey, answer)
assert(offset === answer)
}
}
@@ -1410,11 +1501,90 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
)
}
+ private def testFromSpecificTimestamps(
+ topic: String,
+ failOnDataLoss: Boolean,
+ addPartitions: Boolean,
+ options: (String, String)*): Unit = {
+ def sendMessages(topic: String, msgs: Seq[String], part: Int, ts: Long): Unit = {
+ val records = msgs.map { msg =>
+ new RecordBuilder(topic, msg).partition(part).timestamp(ts).build()
+ }
+ testUtils.sendMessages(records)
+ }
+
+ testUtils.createTopic(topic, partitions = 5)
+
+ val firstTimestamp = System.currentTimeMillis() - 5000
+ sendMessages(topic, Array(-20).map(_.toString), 0, firstTimestamp)
+ sendMessages(topic, Array(-10).map(_.toString), 1, firstTimestamp)
+ sendMessages(topic, Array(0, 1).map(_.toString), 2, firstTimestamp)
+ sendMessages(topic, Array(10, 11).map(_.toString), 3, firstTimestamp)
+ sendMessages(topic, Array(20, 21, 22).map(_.toString), 4, firstTimestamp)
+
+ val secondTimestamp = firstTimestamp + 1000
+ sendMessages(topic, Array(-21, -22).map(_.toString), 0, secondTimestamp)
+ sendMessages(topic, Array(-11, -12).map(_.toString), 1, secondTimestamp)
+ sendMessages(topic, Array(2).map(_.toString), 2, secondTimestamp)
+ sendMessages(topic, Array(12).map(_.toString), 3, secondTimestamp)
+ // no data after second timestamp for partition 4
+
+ require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+ // we intentionally start from the second timestamp,
+ // except for partition 4 - it starts from the first timestamp
+ val startPartitionTimestamps: Map[TopicPartition, Long] = Map(
+ (0 to 3).map(new TopicPartition(topic, _) -> secondTimestamp): _*
+ ) ++ Map(new TopicPartition(topic, 4) -> firstTimestamp)
+ val startingTimestamps = JsonUtils.partitionTimestamps(startPartitionTimestamps)
+
+ val reader = spark
+ .readStream
+ .format("kafka")
+ .option("startingOffsetsByTimestamp", startingTimestamps)
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("failOnDataLoss", failOnDataLoss.toString)
+ options.foreach { case (k, v) => reader.option(k, v) }
+ val kafka = reader.load()
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+ val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
+
+ testStream(mapped)(
+ makeSureGetOffsetCalled,
+ Execute { q =>
+ val partitions = (0 to 4).map(new TopicPartition(topic, _))
+ // wait to reach the last offset in every partition
+ q.awaitOffset(
+ 0, KafkaSourceOffset(partitions.map(tp => tp -> 3L).toMap), streamingTimeout.toMillis)
+ },
+ CheckAnswer(-21, -22, -11, -12, 2, 12, 20, 21, 22),
+ StopStream,
+ StartStream(),
+ CheckAnswer(-21, -22, -11, -12, 2, 12, 20, 21, 22), // Should get the data back on recovery
+ StopStream,
+ AddKafkaData(Set(topic), 30, 31, 32), // Add data when stream is stopped
+ StartStream(),
+ CheckAnswer(-21, -22, -11, -12, 2, 12, 20, 21, 22, 30, 31, 32), // Should get the added data
+ AssertOnQuery("Add partitions") { query: StreamExecution =>
+ if (addPartitions) setTopicPartitions(topic, 10, query)
+ true
+ },
+ AddKafkaData(Set(topic), 40, 41, 42, 43, 44)(ensureDataInMultiplePartition = true),
+ CheckAnswer(-21, -22, -11, -12, 2, 12, 20, 21, 22, 30, 31, 32, 40, 41, 42, 43, 44),
+ StopStream
+ )
+ }
+
test("Kafka column types") {
val now = System.currentTimeMillis()
val topic = newTopic()
testUtils.createTopic(newTopic(), partitions = 1)
- testUtils.sendMessages(topic, Array(1).map(_.toString))
+ testUtils.sendMessage(
+ new RecordBuilder(topic, "1")
+ .headers(Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))).build()
+ )
val kafka = spark
.readStream
@@ -1423,6 +1593,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
.option("kafka.metadata.max.age.ms", "1")
.option("startingOffsets", s"earliest")
.option("subscribe", topic)
+ .option("includeHeaders", "true")
.load()
val query = kafka
@@ -1445,6 +1616,21 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
// producer. So here we just use a low bound to make sure the internal conversion works.
assert(row.getAs[java.sql.Timestamp]("timestamp").getTime >= now, s"Unexpected results: $row")
assert(row.getAs[Int]("timestampType") === 0, s"Unexpected results: $row")
+
+ def checkHeader(row: Row, expected: Seq[(String, Array[Byte])]): Unit = {
+ // array<struct<key: string, value: binary>>
+ val headers = row.getList[Row](row.fieldIndex("headers")).asScala
+ assert(headers.length === expected.length)
+
+ (0 until expected.length).foreach { idx =>
+ val key = headers(idx).getAs[String]("key")
+ val value = headers(idx).getAs[Array[Byte]]("value")
+ assert(key === expected(idx)._1)
+ assert(value === expected(idx)._2)
+ }
+ }
+
+ checkHeader(row, Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8))))
query.stop()
}
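
As a companion to the checkHeader assertions above, a small sketch of the reader-side API (broker and topic are placeholders, an active SparkSession named spark is assumed): headers are not exposed by default and only appear as a column when includeHeaders is enabled.

// Each row gains a headers column of array<struct<key: string, value: binary>>.
val stream = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("subscribe", "events")
  .option("includeHeaders", "true")
  .load()
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
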
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
index b4e1b78c7db4e..063e2e2bc8b77 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
@@ -17,9 +17,11 @@
package org.apache.spark.sql.kafka010
+import java.nio.charset.StandardCharsets.UTF_8
import java.util.Locale
import java.util.concurrent.atomic.AtomicInteger
+import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.util.Random
@@ -27,11 +29,11 @@ import org.apache.kafka.clients.producer.ProducerRecord
import org.apache.kafka.common.TopicPartition
import org.apache.spark.SparkConf
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{DataFrameReader, QueryTest}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@@ -70,7 +72,8 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
protected def createDF(
topic: String,
withOptions: Map[String, String] = Map.empty[String, String],
- brokerAddress: Option[String] = None) = {
+ brokerAddress: Option[String] = None,
+ includeHeaders: Boolean = false) = {
val df = spark
.read
.format("kafka")
@@ -80,7 +83,13 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
withOptions.foreach {
case (key, value) => df.option(key, value)
}
- df.load().selectExpr("CAST(value AS STRING)")
+ if (includeHeaders) {
+ df.option("includeHeaders", "true")
+ df.load()
+ .selectExpr("CAST(value AS STRING)", "headers")
+ } else {
+ df.load().selectExpr("CAST(value AS STRING)")
+ }
}
test("explicit earliest to latest offsets") {
@@ -147,6 +156,214 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
checkAnswer(df, (0 to 30).map(_.toString).toDF)
}
+ test("default starting and ending offsets with headers") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 3)
+ testUtils.sendMessage(
+ new RecordBuilder(topic, "1").headers(Seq()).partition(0).build()
+ )
+ testUtils.sendMessage(
+ new RecordBuilder(topic, "2").headers(
+ Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))).partition(1).build()
+ )
+ testUtils.sendMessage(
+ new RecordBuilder(topic, "3").headers(
+ Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8)))).partition(2).build()
+ )
+
+ // Implicit offset values, should default to earliest and latest
+ val df = createDF(topic, includeHeaders = true)
+ // Test that we default to "earliest" and "latest"
+ checkAnswer(df, Seq(("1", null),
+ ("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))),
+ ("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8))))).toDF)
+ }
+
+ test("timestamp provided for starting and ending") {
+ val (topic, timestamps) = prepareTimestampRelatedUnitTest
+
+ // both timestamps provided: starting at "second", ending at "third"
+ verifyTimestampRelatedQueryResult({ df =>
+ val startPartitionTimestamps: Map[TopicPartition, Long] = Map(
+ (0 to 2).map(new TopicPartition(topic, _) -> timestamps(1)): _*)
+ val startingTimestamps = JsonUtils.partitionTimestamps(startPartitionTimestamps)
+
+ val endPartitionTimestamps = Map(
+ (0 to 2).map(new TopicPartition(topic, _) -> timestamps(2)): _*)
+ val endingTimestamps = JsonUtils.partitionTimestamps(endPartitionTimestamps)
+
+ df.option("startingOffsetsByTimestamp", startingTimestamps)
+ .option("endingOffsetsByTimestamp", endingTimestamps)
+ }, topic, 10 to 19)
+ }
+
+ test("timestamp provided for starting, offset provided for ending") {
+ val (topic, timestamps) = prepareTimestampRelatedUnitTest
+
+ // only the starting timestamp is provided ("first"); the ending bound comes from endingOffsets
+ verifyTimestampRelatedQueryResult({ df =>
+ val startTopicTimestamps = Map(
+ (0 to 2).map(new TopicPartition(topic, _) -> timestamps.head): _*)
+ val startingTimestamps = JsonUtils.partitionTimestamps(startTopicTimestamps)
+
+ val endPartitionOffsets = Map(
+ new TopicPartition(topic, 0) -> -1L, // -1 => latest
+ new TopicPartition(topic, 1) -> -1L,
+ new TopicPartition(topic, 2) -> 1L // explicit offset - take only first one
+ )
+ val endingOffsets = JsonUtils.partitionOffsets(endPartitionOffsets)
+
+ // so here we expect all records from partitions 0 and 1, and only the first record
+ // from partition 2, which is "2"
+
+ df.option("startingOffsetsByTimestamp", startingTimestamps)
+ .option("endingOffsets", endingOffsets)
+ }, topic, (0 to 29).filterNot(_ % 3 == 2) ++ Seq(2))
+ }
+
+ test("timestamp provided for ending, offset provided for starting") {
+ val (topic, timestamps) = prepareTimestampRelatedUnitTest
+
+ // only the ending timestamp is provided ("third"); the starting bound comes from startingOffsets
+ verifyTimestampRelatedQueryResult({ df =>
+ val startPartitionOffsets = Map(
+ new TopicPartition(topic, 0) -> -2L, // -2 => earliest
+ new TopicPartition(topic, 1) -> -2L,
+ new TopicPartition(topic, 2) -> 0L // explicit earliest
+ )
+ val startingOffsets = JsonUtils.partitionOffsets(startPartitionOffsets)
+
+ val endTopicTimestamps = Map(
+ (0 to 2).map(new TopicPartition(topic, _) -> timestamps(2)): _*)
+ val endingTimestamps = JsonUtils.partitionTimestamps(endTopicTimestamps)
+
+ df.option("startingOffsets", startingOffsets)
+ .option("endingOffsetsByTimestamp", endingTimestamps)
+ }, topic, 0 to 19)
+ }
+
+ test("timestamp provided for starting, ending not provided") {
+ val (topic, timestamps) = prepareTimestampRelatedUnitTest
+
+ // only the starting timestamp is provided ("second"); no ending bound is given
+ verifyTimestampRelatedQueryResult({ df =>
+ val startTopicTimestamps = Map(
+ (0 to 2).map(new TopicPartition(topic, _) -> timestamps(1)): _*)
+ val startingTimestamps = JsonUtils.partitionTimestamps(startTopicTimestamps)
+
+ df.option("startingOffsetsByTimestamp", startingTimestamps)
+ }, topic, 10 to 29)
+ }
+
+ test("timestamp provided for ending, starting not provided") {
+ val (topic, timestamps) = prepareTimestampRelatedUnitTest
+
+ // only the ending timestamp is provided ("third"); no starting bound is given
+ verifyTimestampRelatedQueryResult({ df =>
+ val endTopicTimestamps = Map(
+ (0 to 2).map(new TopicPartition(topic, _) -> timestamps(2)): _*)
+ val endingTimestamps = JsonUtils.partitionTimestamps(endTopicTimestamps)
+
+ df.option("endingOffsetsByTimestamp", endingTimestamps)
+ }, topic, 0 to 19)
+ }
+
+ test("no matched offset for timestamp - startingOffsets") {
+ val (topic, timestamps) = prepareTimestampRelatedUnitTest
+
+ val e = intercept[SparkException] {
+ verifyTimestampRelatedQueryResult({ df =>
+ // partition 2 will make query fail
+ val startTopicTimestamps = Map(
+ (0 to 1).map(new TopicPartition(topic, _) -> timestamps(1)): _*) ++
+ Map(new TopicPartition(topic, 2) -> Long.MaxValue)
+
+ val startingTimestamps = JsonUtils.partitionTimestamps(startTopicTimestamps)
+
+ df.option("startingOffsetsByTimestamp", startingTimestamps)
+ }, topic, Seq.empty)
+ }
+
+ @tailrec
+ def assertionErrorInExceptionChain(e: Throwable): Boolean = {
+ if (e.isInstanceOf[AssertionError]) {
+ true
+ } else if (e.getCause == null) {
+ false
+ } else {
+ assertionErrorInExceptionChain(e.getCause)
+ }
+ }
+
+ assert(assertionErrorInExceptionChain(e),
+ "Cannot find expected AssertionError in chained exceptions")
+ }
+
+ test("no matched offset for timestamp - endingOffsets") {
+ val (topic, timestamps) = prepareTimestampRelatedUnitTest
+
+ // the query will run fine, since a timestamp with no matching offset is allowed
+ // when it is used for the ending offsets
+ // for partitions 0 and 1, it only takes records between the first and second timestamps
+ // for partition 2, it will take all records
+ verifyTimestampRelatedQueryResult({ df =>
+ val endTopicTimestamps = Map(
+ (0 to 1).map(new TopicPartition(topic, _) -> timestamps(1)): _*) ++
+ Map(new TopicPartition(topic, 2) -> Long.MaxValue)
+
+ val endingTimestamps = JsonUtils.partitionTimestamps(endTopicTimestamps)
+
+ df.option("endingOffsetsByTimestamp", endingTimestamps)
+ }, topic, (0 to 9) ++ (10 to 29).filter(_ % 3 == 2))
+ }
+
+ private def prepareTimestampRelatedUnitTest: (String, Seq[Long]) = {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 3)
+
+ def sendMessages(topic: String, msgs: Array[String], part: Int, ts: Long): Unit = {
+ val records = msgs.map { msg =>
+ new RecordBuilder(topic, msg).partition(part).timestamp(ts).build()
+ }
+ testUtils.sendMessages(records)
+ }
+
+ val firstTimestamp = System.currentTimeMillis() - 5000
+ (0 to 2).foreach { partNum =>
+ sendMessages(topic, (0 to 9).filter(_ % 3 == partNum)
+ .map(_.toString).toArray, partNum, firstTimestamp)
+ }
+
+ val secondTimestamp = firstTimestamp + 1000
+ (0 to 2).foreach { partNum =>
+ sendMessages(topic, (10 to 19).filter(_ % 3 == partNum)
+ .map(_.toString).toArray, partNum, secondTimestamp)
+ }
+
+ val thirdTimestamp = secondTimestamp + 1000
+ (0 to 2).foreach { partNum =>
+ sendMessages(topic, (20 to 29).filter(_ % 3 == partNum)
+ .map(_.toString).toArray, partNum, thirdTimestamp)
+ }
+
+ val finalizedTimestamp = thirdTimestamp + 1000
+
+ (topic, Seq(firstTimestamp, secondTimestamp, thirdTimestamp, finalizedTimestamp))
+ }
+
+ private def verifyTimestampRelatedQueryResult(
+ optionFn: DataFrameReader => DataFrameReader,
+ topic: String,
+ expectation: Seq[Int]): Unit = {
+ val df = spark.read
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+
+ val df2 = optionFn(df).load().selectExpr("CAST(value AS STRING)")
+ checkAnswer(df2, expectation.map(_.toString).toDF)
+ }
+
test("reuse same dataframe in query") {
// This test ensures that we do not cache the Kafka Consumer in KafkaRelation
val topic = newTopic()
@@ -263,7 +480,8 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
})
}
- private def testGroupId(groupIdKey: String, validateGroupId: (String, Iterable[String]) => Unit) {
+ private def testGroupId(groupIdKey: String,
+ validateGroupId: (String, Iterable[String]) => Unit): Unit = {
// Tests code path KafkaSourceProvider.createRelation(.)
val topic = newTopic()
testUtils.createTopic(topic, partitions = 3)
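
The relation tests above also mix the two option families; a hedged sketch of that combination outside the test harness (placeholders as before): a timestamp lower bound paired with an explicit offset upper bound, each resolved independently.

// Start from a timestamp, stop at explicit per-partition offsets (-1 means latest).
val df = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("subscribe", "events")
  .option("startingOffsetsByTimestamp", """{"events": {"0": 1565000000000}}""")
  .option("endingOffsets", """{"events": {"0": -1}}""")
  .load()
  .selectExpr("CAST(value AS STRING)")
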
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index 84ad41610cccd..d77b9a3b6a9e1 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.kafka010
+import java.nio.charset.StandardCharsets.UTF_8
import java.util.Locale
import java.util.concurrent.atomic.AtomicInteger
@@ -32,7 +33,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{BinaryType, DataType}
+import org.apache.spark.sql.types.{BinaryType, DataType, StringType, StructField, StructType}
abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest {
protected var testUtils: KafkaTestUtils = _
@@ -59,13 +60,14 @@ abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with
protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
- protected def createKafkaReader(topic: String): DataFrame = {
+ protected def createKafkaReader(topic: String, includeHeaders: Boolean = false): DataFrame = {
spark.read
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("startingOffsets", "earliest")
.option("endingOffsets", "latest")
.option("subscribe", topic)
+ .option("includeHeaders", includeHeaders.toString)
.load()
}
}
@@ -368,15 +370,52 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase {
test("batch - write to kafka") {
val topic = newTopic()
testUtils.createTopic(topic)
- val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value")
+ val data = Seq(
+ Row(topic, "1", Seq(
+ Row("a", "b".getBytes(UTF_8))
+ )),
+ Row(topic, "2", Seq(
+ Row("c", "d".getBytes(UTF_8)),
+ Row("e", "f".getBytes(UTF_8))
+ )),
+ Row(topic, "3", Seq(
+ Row("g", "h".getBytes(UTF_8)),
+ Row("g", "i".getBytes(UTF_8))
+ )),
+ Row(topic, "4", null),
+ Row(topic, "5", Seq(
+ Row("j", "k".getBytes(UTF_8)),
+ Row("j", "l".getBytes(UTF_8)),
+ Row("m", "n".getBytes(UTF_8))
+ ))
+ )
+
+ val df = spark.createDataFrame(
+ spark.sparkContext.parallelize(data),
+ StructType(Seq(StructField("topic", StringType), StructField("value", StringType),
+ StructField("headers", KafkaRecordToRowConverter.headersType)))
+ )
+
df.write
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("topic", topic)
+ .mode("append")
.save()
checkAnswer(
- createKafkaReader(topic).selectExpr("CAST(value as STRING) value"),
- Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil)
+ createKafkaReader(topic, includeHeaders = true).selectExpr(
+ "CAST(value as STRING) value", "headers"
+ ),
+ Row("1", Seq(Row("a", "b".getBytes(UTF_8)))) ::
+ Row("2", Seq(Row("c", "d".getBytes(UTF_8)), Row("e", "f".getBytes(UTF_8)))) ::
+ Row("3", Seq(Row("g", "h".getBytes(UTF_8)), Row("g", "i".getBytes(UTF_8)))) ::
+ Row("4", null) ::
+ Row("5", Seq(
+ Row("j", "k".getBytes(UTF_8)),
+ Row("j", "l".getBytes(UTF_8)),
+ Row("m", "n".getBytes(UTF_8)))) ::
+ Nil
+ )
}
test("batch - null topic field value, and no topic option") {
@@ -385,12 +424,13 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase {
df.write
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .mode("append")
.save()
}
TestUtils.assertExceptionMsg(ex, "null topic present in the data")
}
- protected def testUnsupportedSaveModes(msg: (SaveMode) => String) {
+ protected def testUnsupportedSaveModes(msg: (SaveMode) => String): Unit = {
val topic = newTopic()
testUtils.createTopic(topic)
val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value")
@@ -419,6 +459,7 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase {
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("topic", topic)
+ .mode("append")
.save()
}
}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala
index 8e6de88865e06..f7b00b31ebba0 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.mockito.Mockito.{mock, when}
import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
-import org.apache.spark.sql.sources.v2.reader.Scan
+import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.util.CaseInsensitiveStringMap
class KafkaSourceProviderSuite extends SparkFunSuite {
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
index d7cb30f530396..bbb72bf9973e3 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.kafka010
import java.io.{File, IOException}
import java.lang.{Integer => JInt}
-import java.net.InetSocketAddress
+import java.net.{InetAddress, InetSocketAddress}
import java.nio.charset.StandardCharsets
import java.util.{Collections, Map => JMap, Properties, UUID}
import java.util.concurrent.TimeUnit
@@ -41,6 +41,8 @@ import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.producer._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.config.SaslConfigs
+import org.apache.kafka.common.header.Header
+import org.apache.kafka.common.header.internals.RecordHeader
import org.apache.kafka.common.network.ListenerName
import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT}
import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer}
@@ -66,10 +68,13 @@ class KafkaTestUtils(
private val JAVA_AUTH_CONFIG = "java.security.auth.login.config"
+ private val localCanonicalHostName = InetAddress.getLoopbackAddress().getCanonicalHostName()
+ logInfo(s"Local host name is $localCanonicalHostName")
+
private var kdc: MiniKdc = _
// Zookeeper related configurations
- private val zkHost = "localhost"
+ private val zkHost = localCanonicalHostName
private var zkPort: Int = 0
private val zkConnectionTimeout = 60000
private val zkSessionTimeout = 10000
@@ -78,12 +83,12 @@ class KafkaTestUtils(
private var zkUtils: ZkUtils = _
// Kafka broker related configurations
- private val brokerHost = "localhost"
+ private val brokerHost = localCanonicalHostName
private var brokerPort = 0
private var brokerConf: KafkaConfig = _
private val brokerServiceName = "kafka"
- private val clientUser = "client/localhost"
+ private val clientUser = s"client/$localCanonicalHostName"
private var clientKeytabFile: File = _
// Kafka broker server
@@ -137,17 +142,17 @@ class KafkaTestUtils(
assert(kdcReady, "KDC should be set up beforehand")
val baseDir = Utils.createTempDir()
- val zkServerUser = "zookeeper/localhost"
+ val zkServerUser = s"zookeeper/$localCanonicalHostName"
val zkServerKeytabFile = new File(baseDir, "zookeeper.keytab")
kdc.createPrincipal(zkServerKeytabFile, zkServerUser)
logDebug(s"Created keytab file: ${zkServerKeytabFile.getAbsolutePath()}")
- val zkClientUser = "zkclient/localhost"
+ val zkClientUser = s"zkclient/$localCanonicalHostName"
val zkClientKeytabFile = new File(baseDir, "zkclient.keytab")
kdc.createPrincipal(zkClientKeytabFile, zkClientUser)
logDebug(s"Created keytab file: ${zkClientKeytabFile.getAbsolutePath()}")
- val kafkaServerUser = "kafka/localhost"
+ val kafkaServerUser = s"kafka/$localCanonicalHostName"
val kafkaServerKeytabFile = new File(baseDir, "kafka.keytab")
kdc.createPrincipal(kafkaServerKeytabFile, kafkaServerUser)
logDebug(s"Created keytab file: ${kafkaServerKeytabFile.getAbsolutePath()}")
@@ -348,38 +353,33 @@ class KafkaTestUtils(
}
}
- /** Java-friendly function for sending messages to the Kafka broker */
- def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = {
- sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*))
+ def sendMessages(topic: String, msgs: Array[String]): Seq[(String, RecordMetadata)] = {
+ sendMessages(topic, msgs, None)
}
- /** Send the messages to the Kafka broker */
- def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = {
- val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray
- sendMessages(topic, messages)
+ def sendMessages(
+ topic: String,
+ msgs: Array[String],
+ part: Option[Int]): Seq[(String, RecordMetadata)] = {
+ val records = msgs.map { msg =>
+ val builder = new RecordBuilder(topic, msg)
+ part.foreach { p => builder.partition(p) }
+ builder.build()
+ }
+ sendMessages(records)
}
- /** Send the array of messages to the Kafka broker */
- def sendMessages(topic: String, messages: Array[String]): Seq[(String, RecordMetadata)] = {
- sendMessages(topic, messages, None)
+ def sendMessage(msg: ProducerRecord[String, String]): Seq[(String, RecordMetadata)] = {
+ sendMessages(Array(msg))
}
- /** Send the array of messages to the Kafka broker using specified partition */
- def sendMessages(
- topic: String,
- messages: Array[String],
- partition: Option[Int]): Seq[(String, RecordMetadata)] = {
+ def sendMessages(msgs: Seq[ProducerRecord[String, String]]): Seq[(String, RecordMetadata)] = {
producer = new KafkaProducer[String, String](producerConfiguration)
val offsets = try {
- messages.map { m =>
- val record = partition match {
- case Some(p) => new ProducerRecord[String, String](topic, p, null, m)
- case None => new ProducerRecord[String, String](topic, m)
- }
- val metadata =
- producer.send(record).get(10, TimeUnit.SECONDS)
- logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}")
- (m, metadata)
+ msgs.map { msg =>
+ val metadata = producer.send(msg).get(10, TimeUnit.SECONDS)
+ logInfo(s"\tSent ($msg) to partition ${metadata.partition}, offset ${metadata.offset}")
+ (msg.value(), metadata)
}
} finally {
if (producer != null) {
@@ -550,7 +550,7 @@ class KafkaTestUtils(
zkUtils: ZkUtils,
topic: String,
numPartitions: Int,
- servers: Seq[KafkaServer]) {
+ servers: Seq[KafkaServer]): Unit = {
eventually(timeout(1.minute), interval(200.milliseconds)) {
try {
verifyTopicDeletion(topic, numPartitions, servers)
@@ -613,7 +613,7 @@ class KafkaTestUtils(
val actualPort = factory.getLocalPort
- def shutdown() {
+ def shutdown(): Unit = {
factory.shutdown()
// The directories are not closed even if the ZooKeeper server is shut down.
// Please see ZOOKEEPER-1844, which is fixed in 3.4.6+. It leads to test failures
@@ -634,4 +634,3 @@ class KafkaTestUtils(
}
}
}
-
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/RecordBuilder.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/RecordBuilder.scala
new file mode 100644
index 0000000000000..ef07798442e56
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/RecordBuilder.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.lang.{Integer => JInt, Long => JLong}
+
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.clients.producer.ProducerRecord
+import org.apache.kafka.common.header.Header
+import org.apache.kafka.common.header.internals.RecordHeader
+
+class RecordBuilder(topic: String, value: String) {
+ var _partition: Option[JInt] = None
+ var _timestamp: Option[JLong] = None
+ var _key: Option[String] = None
+ var _headers: Option[Seq[(String, Array[Byte])]] = None
+
+ def partition(part: JInt): RecordBuilder = {
+ _partition = Some(part)
+ this
+ }
+
+ def partition(part: Int): RecordBuilder = {
+ _partition = Some(part.intValue())
+ this
+ }
+
+ def timestamp(ts: JLong): RecordBuilder = {
+ _timestamp = Some(ts)
+ this
+ }
+
+ def timestamp(ts: Long): RecordBuilder = {
+ _timestamp = Some(ts.longValue())
+ this
+ }
+
+ def key(k: String): RecordBuilder = {
+ _key = Some(k)
+ this
+ }
+
+ def headers(hdrs: Seq[(String, Array[Byte])]): RecordBuilder = {
+ _headers = Some(hdrs)
+ this
+ }
+
+ def build(): ProducerRecord[String, String] = {
+ val part = _partition.orNull
+ val ts = _timestamp.orNull
+ val k = _key.orNull
+ val hdrs = _headers.map { h =>
+ h.map { case (k, v) => new RecordHeader(k, v).asInstanceOf[Header] }
+ }.map(_.asJava).orNull
+
+ new ProducerRecord[String, String](topic, part, ts, k, value, hdrs)
+ }
+}
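
For context, a minimal sketch of how a test might use the new RecordBuilder together with the refactored sendMessage helper above (topic, key and header values are illustrative placeholders, not part of this patch):

import java.nio.charset.StandardCharsets
import org.apache.kafka.clients.producer.ProducerRecord

// Build a keyed record with a header and hand it to KafkaTestUtils.
val record: ProducerRecord[String, String] = new RecordBuilder("topic-1", "payload")
  .partition(0)
  .key("key-1")
  .headers(Seq("trace-id" -> "abc123".getBytes(StandardCharsets.UTF_8)))
  .build()
// testUtils.sendMessage(record) then goes through the unified sendMessages(Seq[...]) path.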
diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaConfigUpdater.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaConfigUpdater.scala
index 0c61045d6d487..f54ff0d146f7a 100644
--- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaConfigUpdater.scala
+++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaConfigUpdater.scala
@@ -57,6 +57,12 @@ private[spark] case class KafkaConfigUpdater(module: String, kafkaParams: Map[St
}
def setAuthenticationConfigIfNeeded(): this.type = {
+ val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(SparkEnv.get.conf,
+ kafkaParams(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG).asInstanceOf[String])
+ setAuthenticationConfigIfNeeded(clusterConfig)
+ }
+
+ def setAuthenticationConfigIfNeeded(clusterConfig: Option[KafkaTokenClusterConf]): this.type = {
// There are multiple possibilities to log in and applied in the following order:
// - JVM global security provided -> try to log in with JVM global security configuration
// which can be configured for example with 'java.security.auth.login.config'.
@@ -66,10 +72,9 @@ private[spark] case class KafkaConfigUpdater(module: String, kafkaParams: Map[St
if (KafkaTokenUtil.isGlobalJaasConfigurationProvided) {
logDebug("JVM global security configuration detected, using it for login.")
} else {
- val clusterConfig = KafkaTokenUtil.findMatchingToken(SparkEnv.get.conf,
- map.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG).asInstanceOf[String])
clusterConfig.foreach { clusterConf =>
logDebug("Delegation token detected, using it for login.")
+ setIfUnset(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, clusterConf.securityProtocol)
val jaasParams = KafkaTokenUtil.getTokenJaasParams(clusterConf)
set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams)
require(clusterConf.tokenMechanism.startsWith("SCRAM"),
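
As a rough illustration of the new overload (not part of the patch; `kafkaParams` and the module name are placeholders, and the map is assumed to carry bootstrap.servers as in the tests below):

import org.apache.kafka.clients.CommonClientConfigs
import org.apache.spark.SparkEnv

// Resolve the delegation-token cluster config once, then reuse it when building the
// authenticated client parameters, mirroring what setAuthenticationConfigIfNeeded() now does.
val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(
  SparkEnv.get.conf,
  kafkaParams(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG).asInstanceOf[String])
val authenticatedParams = KafkaConfigUpdater("source", kafkaParams)
  .setAuthenticationConfigIfNeeded(clusterConfig)
  .build()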
diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala
index e1f3c800a51f8..ed4a6f1e34c55 100644
--- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala
+++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenSparkConf.scala
@@ -57,6 +57,7 @@ private [kafka010] object KafkaTokenSparkConf extends Logging {
val CLUSTERS_CONFIG_PREFIX = "spark.kafka.clusters."
val DEFAULT_TARGET_SERVERS_REGEX = ".*"
val DEFAULT_SASL_KERBEROS_SERVICE_NAME = "kafka"
+ val DEFAULT_SECURITY_PROTOCOL_CONFIG = SASL_SSL.name
val DEFAULT_SASL_TOKEN_MECHANISM = "SCRAM-SHA-512"
def getClusterConfig(sparkConf: SparkConf, identifier: String): KafkaTokenClusterConf = {
@@ -72,7 +73,8 @@ private [kafka010] object KafkaTokenSparkConf extends Logging {
s"${configPrefix}auth.${CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG}")),
sparkClusterConf.getOrElse(s"target.${CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG}.regex",
KafkaTokenSparkConf.DEFAULT_TARGET_SERVERS_REGEX),
- sparkClusterConf.getOrElse(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, SASL_SSL.name),
+ sparkClusterConf.getOrElse(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG,
+ DEFAULT_SECURITY_PROTOCOL_CONFIG),
sparkClusterConf.getOrElse(SaslConfigs.SASL_KERBEROS_SERVICE_NAME,
KafkaTokenSparkConf.DEFAULT_SASL_KERBEROS_SERVICE_NAME),
sparkClusterConf.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG),
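
With the new constant, SASL_SSL becomes the default security protocol for delegation-token clusters; a hedged sketch of overriding it per cluster through SparkConf (the cluster1 identifier and broker address are placeholders):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.kafka.clusters.cluster1.auth.bootstrap.servers", "broker-1:9096")
  // Falls back to DEFAULT_SECURITY_PROTOCOL_CONFIG (SASL_SSL) when this key is absent.
  .set("spark.kafka.clusters.cluster1.security.protocol", "SASL_PLAINTEXT")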
diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala
index 39e3ac74a9aeb..0ebe98330b4ae 100644
--- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala
+++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaTokenUtil.scala
@@ -36,7 +36,7 @@ import org.apache.kafka.common.security.auth.SecurityProtocol.{SASL_PLAINTEXT, S
import org.apache.kafka.common.security.scram.ScramLoginModule
import org.apache.kafka.common.security.token.delegation.DelegationToken
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -241,8 +241,8 @@ private[spark] object KafkaTokenUtil extends Logging {
"TOKENID", "HMAC", "OWNER", "RENEWERS", "ISSUEDATE", "EXPIRYDATE", "MAXDATE"))
val tokenInfo = token.tokenInfo
logDebug("%-15s %-15s %-15s %-25s %-15s %-15s %-15s".format(
- REDACTION_REPLACEMENT_TEXT,
tokenInfo.tokenId,
+ REDACTION_REPLACEMENT_TEXT,
tokenInfo.owner,
tokenInfo.renewersAsString,
dateFormat.format(tokenInfo.issueTimestamp),
@@ -251,7 +251,7 @@ private[spark] object KafkaTokenUtil extends Logging {
}
}
- def findMatchingToken(
+ def findMatchingTokenClusterConfig(
sparkConf: SparkConf,
bootStrapServers: String): Option[KafkaTokenClusterConf] = {
val tokens = UserGroupInformation.getCurrentUser().getCredentials.getAllTokens.asScala
@@ -272,6 +272,7 @@ private[spark] object KafkaTokenUtil extends Logging {
def getTokenJaasParams(clusterConf: KafkaTokenClusterConf): String = {
val token = UserGroupInformation.getCurrentUser().getCredentials.getToken(
getTokenService(clusterConf.identifier))
+ require(token != null, s"Token for identifier ${clusterConf.identifier} must exist")
val username = new String(token.getIdentifier)
val password = new String(token.getPassword)
@@ -288,4 +289,17 @@ private[spark] object KafkaTokenUtil extends Logging {
params
}
+
+ def isConnectorUsingCurrentToken(
+ params: ju.Map[String, Object],
+ clusterConfig: Option[KafkaTokenClusterConf]): Boolean = {
+ if (params.containsKey(SaslConfigs.SASL_JAAS_CONFIG)) {
+ logDebug("Delegation token used by connector, checking if uses the latest token.")
+ val consumerJaasParams = params.get(SaslConfigs.SASL_JAAS_CONFIG).asInstanceOf[String]
+ require(clusterConfig.isDefined, "Delegation token must exist for this connector.")
+ getTokenJaasParams(clusterConfig.get) == consumerJaasParams
+ } else {
+ true
+ }
+ }
}
diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaConfigUpdaterSuite.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaConfigUpdaterSuite.scala
index 7a172892e778c..dc1e7cb8d979e 100644
--- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaConfigUpdaterSuite.scala
+++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaConfigUpdaterSuite.scala
@@ -17,8 +17,13 @@
package org.apache.spark.kafka010
+import java.{util => ju}
+
+import scala.collection.JavaConverters._
+
import org.apache.kafka.clients.CommonClientConfigs
import org.apache.kafka.common.config.SaslConfigs
+import org.apache.kafka.common.security.auth.SecurityProtocol.SASL_PLAINTEXT
import org.apache.spark.SparkFunSuite
@@ -62,36 +67,64 @@ class KafkaConfigUpdaterSuite extends SparkFunSuite with KafkaDelegationTokenTes
}
test("setAuthenticationConfigIfNeeded with global security should not set values") {
- val params = Map.empty[String, String]
+ val params = Map(
+ CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG -> bootStrapServers
+ )
+ setSparkEnv(
+ Map(
+ s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers
+ )
+ )
setGlobalKafkaClientConfig()
val updatedParams = KafkaConfigUpdater(testModule, params)
.setAuthenticationConfigIfNeeded()
.build()
- assert(updatedParams.size() === 0)
+ assert(updatedParams.asScala === params)
}
test("setAuthenticationConfigIfNeeded with token should set values") {
val params = Map(
CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG -> bootStrapServers
)
+ testWithTokenSetValues(params) { updatedParams =>
+ assert(updatedParams.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) ===
+ KafkaTokenSparkConf.DEFAULT_SECURITY_PROTOCOL_CONFIG)
+ }
+ }
+
+ test("setAuthenticationConfigIfNeeded with token should not override user-defined protocol") {
+ val overrideProtocolName = SASL_PLAINTEXT.name
+ val params = Map(
+ CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG -> bootStrapServers,
+ CommonClientConfigs.SECURITY_PROTOCOL_CONFIG -> overrideProtocolName
+ )
+ testWithTokenSetValues(params) { updatedParams =>
+ assert(updatedParams.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) ===
+ overrideProtocolName)
+ }
+ }
+
+ def testWithTokenSetValues(params: Map[String, String])
+ (validate: (ju.Map[String, Object]) => Unit): Unit = {
setSparkEnv(
Map(
s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers
)
)
- addTokenToUGI(tokenService1)
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
val updatedParams = KafkaConfigUpdater(testModule, params)
.setAuthenticationConfigIfNeeded()
.build()
- assert(updatedParams.size() === 3)
+ assert(updatedParams.size() === 4)
assert(updatedParams.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) === bootStrapServers)
assert(updatedParams.containsKey(SaslConfigs.SASL_JAAS_CONFIG))
assert(updatedParams.get(SaslConfigs.SASL_MECHANISM) ===
KafkaTokenSparkConf.DEFAULT_SASL_TOKEN_MECHANISM)
+ validate(updatedParams)
}
test("setAuthenticationConfigIfNeeded with invalid mechanism should throw exception") {
@@ -104,7 +137,7 @@ class KafkaConfigUpdaterSuite extends SparkFunSuite with KafkaDelegationTokenTes
s"spark.kafka.clusters.$identifier1.sasl.token.mechanism" -> "intentionally_invalid"
)
)
- addTokenToUGI(tokenService1)
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
val e = intercept[IllegalArgumentException] {
KafkaConfigUpdater(testModule, params)
diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala
index eebbf96afa470..19335f4221e40 100644
--- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala
+++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaDelegationTokenTest.scala
@@ -37,8 +37,12 @@ trait KafkaDelegationTokenTest extends BeforeAndAfterEach {
private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*)
- protected val tokenId = "tokenId" + ju.UUID.randomUUID().toString
- protected val tokenPassword = "tokenPassword" + ju.UUID.randomUUID().toString
+ private var savedSparkEnv: SparkEnv = _
+
+ protected val tokenId1 = "tokenId" + ju.UUID.randomUUID().toString
+ protected val tokenPassword1 = "tokenPassword" + ju.UUID.randomUUID().toString
+ protected val tokenId2 = "tokenId" + ju.UUID.randomUUID().toString
+ protected val tokenPassword2 = "tokenPassword" + ju.UUID.randomUUID().toString
protected val identifier1 = "cluster1"
protected val identifier2 = "cluster2"
@@ -72,11 +76,16 @@ trait KafkaDelegationTokenTest extends BeforeAndAfterEach {
}
}
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ savedSparkEnv = SparkEnv.get
+ }
+
override def afterEach(): Unit = {
try {
Configuration.setConfiguration(null)
- UserGroupInformation.setLoginUser(null)
- SparkEnv.set(null)
+ UserGroupInformation.reset()
+ SparkEnv.set(savedSparkEnv)
} finally {
super.afterEach()
}
@@ -86,7 +95,7 @@ trait KafkaDelegationTokenTest extends BeforeAndAfterEach {
Configuration.setConfiguration(new KafkaJaasConfiguration)
}
- protected def addTokenToUGI(tokenService: Text): Unit = {
+ protected def addTokenToUGI(tokenService: Text, tokenId: String, tokenPassword: String): Unit = {
val token = new Token[KafkaDelegationTokenIdentifier](
tokenId.getBytes,
tokenPassword.getBytes,
diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaRedactionUtilSuite.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaRedactionUtilSuite.scala
index 42a9fb5567b6f..225afbe5f3649 100644
--- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaRedactionUtilSuite.scala
+++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaRedactionUtilSuite.scala
@@ -68,7 +68,7 @@ class KafkaRedactionUtilSuite extends SparkFunSuite with KafkaDelegationTokenTes
test("redactParams should redact token password from parameters") {
setSparkEnv(Map.empty)
val groupId = "id-" + ju.UUID.randomUUID().toString
- addTokenToUGI(tokenService1)
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
val clusterConf = createClusterConf(identifier1, SASL_SSL.name)
val jaasParams = KafkaTokenUtil.getTokenJaasParams(clusterConf)
val kafkaParams = Seq(
@@ -81,8 +81,8 @@ class KafkaRedactionUtilSuite extends SparkFunSuite with KafkaDelegationTokenTes
assert(redactedParams.size === 2)
assert(redactedParams.get(ConsumerConfig.GROUP_ID_CONFIG).get === groupId)
val redactedJaasParams = redactedParams.get(SaslConfigs.SASL_JAAS_CONFIG).get
- assert(redactedJaasParams.contains(tokenId))
- assert(!redactedJaasParams.contains(tokenPassword))
+ assert(redactedJaasParams.contains(tokenId1))
+ assert(!redactedJaasParams.contains(tokenPassword1))
}
test("redactParams should redact passwords from parameters") {
@@ -113,13 +113,13 @@ class KafkaRedactionUtilSuite extends SparkFunSuite with KafkaDelegationTokenTes
}
test("redactJaasParam should redact token password") {
- addTokenToUGI(tokenService1)
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
val clusterConf = createClusterConf(identifier1, SASL_SSL.name)
val jaasParams = KafkaTokenUtil.getTokenJaasParams(clusterConf)
val redactedJaasParams = KafkaRedactionUtil.redactJaasParam(jaasParams)
- assert(redactedJaasParams.contains(tokenId))
- assert(!redactedJaasParams.contains(tokenPassword))
+ assert(redactedJaasParams.contains(tokenId1))
+ assert(!redactedJaasParams.contains(tokenPassword1))
}
}
diff --git a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala
index 5496195b41490..6fa1b56bff977 100644
--- a/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala
+++ b/external/kafka-0-10-token-provider/src/test/scala/org/apache/spark/kafka010/KafkaTokenUtilSuite.scala
@@ -17,15 +17,18 @@
package org.apache.spark.kafka010
+import java.{util => ju}
import java.security.PrivilegedExceptionAction
+import scala.collection.JavaConverters._
+
import org.apache.hadoop.io.Text
import org.apache.hadoop.security.UserGroupInformation
import org.apache.kafka.clients.CommonClientConfigs
import org.apache.kafka.common.config.{SaslConfigs, SslConfigs}
import org.apache.kafka.common.security.auth.SecurityProtocol.{SASL_PLAINTEXT, SASL_SSL, SSL}
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
import org.apache.spark.internal.config._
class KafkaTokenUtilSuite extends SparkFunSuite with KafkaDelegationTokenTest {
@@ -174,58 +177,102 @@ class KafkaTokenUtilSuite extends SparkFunSuite with KafkaDelegationTokenTest {
assert(KafkaTokenUtil.isGlobalJaasConfigurationProvided)
}
- test("findMatchingToken without token should return None") {
- assert(KafkaTokenUtil.findMatchingToken(sparkConf, bootStrapServers) === None)
+ test("findMatchingTokenClusterConfig without token should return None") {
+ assert(KafkaTokenUtil.findMatchingTokenClusterConfig(sparkConf, bootStrapServers) === None)
}
- test("findMatchingToken with non-matching tokens should return None") {
+ test("findMatchingTokenClusterConfig with non-matching tokens should return None") {
sparkConf.set(s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers", bootStrapServers)
sparkConf.set(s"spark.kafka.clusters.$identifier1.target.bootstrap.servers.regex",
nonMatchingTargetServersRegex)
sparkConf.set(s"spark.kafka.clusters.$identifier2.bootstrap.servers", bootStrapServers)
sparkConf.set(s"spark.kafka.clusters.$identifier2.target.bootstrap.servers.regex",
matchingTargetServersRegex)
- addTokenToUGI(tokenService1)
- addTokenToUGI(new Text("intentionally_garbage"))
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
+ addTokenToUGI(new Text("intentionally_garbage"), tokenId1, tokenPassword1)
- assert(KafkaTokenUtil.findMatchingToken(sparkConf, bootStrapServers) === None)
+ assert(KafkaTokenUtil.findMatchingTokenClusterConfig(sparkConf, bootStrapServers) === None)
}
- test("findMatchingToken with one matching token should return cluster configuration") {
+ test("findMatchingTokenClusterConfig with one matching token should return token and cluster " +
+ "configuration") {
sparkConf.set(s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers", bootStrapServers)
sparkConf.set(s"spark.kafka.clusters.$identifier1.target.bootstrap.servers.regex",
matchingTargetServersRegex)
- addTokenToUGI(tokenService1)
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
- assert(KafkaTokenUtil.findMatchingToken(sparkConf, bootStrapServers) ===
- Some(KafkaTokenSparkConf.getClusterConfig(sparkConf, identifier1)))
+ val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(sparkConf, bootStrapServers)
+ assert(clusterConfig.get === KafkaTokenSparkConf.getClusterConfig(sparkConf, identifier1))
}
- test("findMatchingToken with multiple matching tokens should throw exception") {
+ test("findMatchingTokenClusterConfig with multiple matching tokens should throw exception") {
sparkConf.set(s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers", bootStrapServers)
sparkConf.set(s"spark.kafka.clusters.$identifier1.target.bootstrap.servers.regex",
matchingTargetServersRegex)
sparkConf.set(s"spark.kafka.clusters.$identifier2.auth.bootstrap.servers", bootStrapServers)
sparkConf.set(s"spark.kafka.clusters.$identifier2.target.bootstrap.servers.regex",
matchingTargetServersRegex)
- addTokenToUGI(tokenService1)
- addTokenToUGI(tokenService2)
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
+ addTokenToUGI(tokenService2, tokenId1, tokenPassword1)
val thrown = intercept[IllegalArgumentException] {
- KafkaTokenUtil.findMatchingToken(sparkConf, bootStrapServers)
+ KafkaTokenUtil.findMatchingTokenClusterConfig(sparkConf, bootStrapServers)
}
assert(thrown.getMessage.contains("More than one delegation token matches"))
}
test("getTokenJaasParams with token should return scram module") {
- addTokenToUGI(tokenService1)
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
val clusterConf = createClusterConf(identifier1, SASL_SSL.name)
val jaasParams = KafkaTokenUtil.getTokenJaasParams(clusterConf)
assert(jaasParams.contains("ScramLoginModule required"))
assert(jaasParams.contains("tokenauth=true"))
- assert(jaasParams.contains(tokenId))
- assert(jaasParams.contains(tokenPassword))
+ assert(jaasParams.contains(tokenId1))
+ assert(jaasParams.contains(tokenPassword1))
+ }
+
+ test("isConnectorUsingCurrentToken without security should return true") {
+ val kafkaParams = Map[String, Object]().asJava
+
+ assert(KafkaTokenUtil.isConnectorUsingCurrentToken(kafkaParams, None))
+ }
+
+ test("isConnectorUsingCurrentToken with same token should return true") {
+ setSparkEnv(
+ Map(
+ s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers
+ )
+ )
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
+ val kafkaParams = getKafkaParams()
+ val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(SparkEnv.get.conf,
+ kafkaParams.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG).asInstanceOf[String])
+
+ assert(KafkaTokenUtil.isConnectorUsingCurrentToken(kafkaParams, clusterConfig))
+ }
+
+ test("isConnectorUsingCurrentToken with different token should return false") {
+ setSparkEnv(
+ Map(
+ s"spark.kafka.clusters.$identifier1.auth.bootstrap.servers" -> bootStrapServers
+ )
+ )
+ addTokenToUGI(tokenService1, tokenId1, tokenPassword1)
+ val kafkaParams = getKafkaParams()
+ addTokenToUGI(tokenService1, tokenId2, tokenPassword2)
+ val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig(SparkEnv.get.conf,
+ kafkaParams.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG).asInstanceOf[String])
+
+ assert(!KafkaTokenUtil.isConnectorUsingCurrentToken(kafkaParams, clusterConfig))
+ }
+
+ private def getKafkaParams(): ju.Map[String, Object] = {
+ val clusterConf = createClusterConf(identifier1, SASL_SSL.name)
+ Map[String, Object](
+ CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG -> bootStrapServers,
+ SaslConfigs.SASL_JAAS_CONFIG -> KafkaTokenUtil.getTokenJaasParams(clusterConf)
+ ).asJava
}
}
diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml
index 397de87d3cdff..d11569d709b23 100644
--- a/external/kafka-0-10/pom.xml
+++ b/external/kafka-0-10/pom.xml
@@ -45,6 +45,13 @@
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
index 4d3e476e7cc58..925327d9d58e6 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
@@ -18,8 +18,8 @@
package org.apache.spark.streaming.kafka010
import java.io.File
-import java.lang.{ Long => JLong }
-import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID }
+import java.lang.{Long => JLong}
+import java.util.{Arrays, HashMap => JHashMap, Map => JMap, UUID}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicLong
@@ -31,13 +31,12 @@ import scala.util.Random
import org.apache.kafka.clients.consumer._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.StringDeserializer
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
+import org.apache.spark.streaming.{LocalStreamingContext, Milliseconds, StreamingContext, Time}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.scheduler._
import org.apache.spark.streaming.scheduler.rate.RateEstimator
@@ -45,8 +44,7 @@ import org.apache.spark.util.Utils
class DirectKafkaStreamSuite
extends SparkFunSuite
- with BeforeAndAfter
- with BeforeAndAfterAll
+ with LocalStreamingContext
with Eventually
with Logging {
val sparkConf = new SparkConf()
@@ -56,18 +54,17 @@ class DirectKafkaStreamSuite
// Otherwise the poll timeout defaults to 2 minutes and causes test cases to run longer.
.set("spark.streaming.kafka.consumer.poll.ms", "10000")
- private var ssc: StreamingContext = _
private var testDir: File = _
private var kafkaTestUtils: KafkaTestUtils = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
kafkaTestUtils = new KafkaTestUtils
kafkaTestUtils.setup()
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
if (kafkaTestUtils != null) {
kafkaTestUtils.teardown()
@@ -78,12 +75,13 @@ class DirectKafkaStreamSuite
}
}
- after {
- if (ssc != null) {
- ssc.stop(stopSparkContext = true)
- }
- if (testDir != null) {
- Utils.deleteRecursively(testDir)
+ override def afterEach(): Unit = {
+ try {
+ if (testDir != null) {
+ Utils.deleteRecursively(testDir)
+ }
+ } finally {
+ super.afterEach()
}
}
@@ -342,7 +340,7 @@ class DirectKafkaStreamSuite
val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
// Send data to Kafka
- def sendData(data: Seq[Int]) {
+ def sendData(data: Seq[Int]): Unit = {
val strings = data.map { _.toString}
kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap)
}
@@ -434,7 +432,7 @@ class DirectKafkaStreamSuite
val committed = new ConcurrentHashMap[TopicPartition, OffsetAndMetadata]()
// Send data to Kafka and wait for it to be received
- def sendDataAndWaitForReceive(data: Seq[Int]) {
+ def sendDataAndWaitForReceive(data: Seq[Int]): Unit = {
val strings = data.map { _.toString}
kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap)
eventually(timeout(10.seconds), interval(50.milliseconds)) {
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
index 431473e7f1d38..82913cf416a5f 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
@@ -27,7 +27,7 @@ import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.ByteArrayDeserializer
import org.mockito.Mockito.when
import org.scalatest.BeforeAndAfterAll
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark._
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
index 47bc8fec2c80c..d6123e16dd238 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
@@ -47,14 +47,14 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
private var sc: SparkContext = _
- override def beforeAll {
+ override def beforeAll: Unit = {
super.beforeAll()
sc = new SparkContext(sparkConf)
kafkaTestUtils = new KafkaTestUtils
kafkaTestUtils.setup()
}
- override def afterAll {
+ override def afterAll: Unit = {
try {
try {
if (sc != null) {
@@ -81,7 +81,8 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
private val preferredHosts = LocationStrategies.PreferConsistent
- private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) {
+ private def compactLogs(topic: String, partition: Int,
+ messages: Array[(String, String)]): Unit = {
val mockTime = new MockTime()
val logs = new Pool[TopicPartition, Log]()
val logDir = kafkaTestUtils.brokerLogDir
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala
index 5dec9709011e6..999870acfb532 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala
@@ -316,7 +316,7 @@ private[kafka010] class KafkaTestUtils extends Logging {
val actualPort = factory.getLocalPort
- def shutdown() {
+ def shutdown(): Unit = {
factory.shutdown()
// The directories are not closed even if the ZooKeeper server is shut down.
// Please see ZOOKEEPER-1844, which is fixed in 3.4.6+. It leads to test failures
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala
index dedd691cd1b23..d38ed9fc9263d 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala
@@ -45,7 +45,7 @@ private[kafka010] class MockTime(@volatile private var currentMs: Long) extends
override def nanoseconds: Long =
TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS)
- override def sleep(ms: Long) {
+ override def sleep(ms: Long): Unit = {
this.currentMs += ms
scheduler.tick()
}
diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
index 86c42df9e8435..31ca2fe5c95ff 100644
--- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
+++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
@@ -32,13 +32,14 @@
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
-import org.apache.spark.streaming.kinesis.KinesisUtils;
+import org.apache.spark.streaming.kinesis.KinesisInitialPositions;
+import org.apache.spark.streaming.kinesis.KinesisInputDStream;
import scala.Tuple2;
+import scala.reflect.ClassTag$;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.kinesis.AmazonKinesisClient;
-import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
/**
 * Consumes messages from an Amazon Kinesis stream and does wordcount.
@@ -135,11 +136,19 @@ public static void main(String[] args) throws Exception {
// Create the Kinesis DStreams
List<JavaDStream<byte[]>> streamsList = new ArrayList<>(numStreams);
for (int i = 0; i < numStreams; i++) {
- streamsList.add(
- KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
- InitialPositionInStream.LATEST, kinesisCheckpointInterval,
- StorageLevel.MEMORY_AND_DISK_2())
- );
+ streamsList.add(JavaDStream.fromDStream(
+ KinesisInputDStream.builder()
+ .streamingContext(jssc)
+ .checkpointAppName(kinesisAppName)
+ .streamName(streamName)
+ .endpointUrl(endpointUrl)
+ .regionName(regionName)
+ .initialPosition(new KinesisInitialPositions.Latest())
+ .checkpointInterval(kinesisCheckpointInterval)
+ .storageLevel(StorageLevel.MEMORY_AND_DISK_2())
+ .build(),
+ ClassTag$.MODULE$.apply(byte[].class)
+ ));
}
// Union all the streams if there is more than 1 stream
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
index fcb790e3ea1f9..a5d5ac769b28d 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
@@ -73,7 +73,7 @@ import org.apache.spark.streaming.kinesis.KinesisInputDStream
* the Kinesis Spark Streaming integration.
*/
object KinesisWordCountASL extends Logging {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
// Check that all required args were passed in.
if (args.length != 3) {
System.err.println(
@@ -178,7 +178,7 @@ object KinesisWordCountASL extends Logging {
* https://kinesis.us-east-1.amazonaws.com us-east-1 10 5
*/
object KinesisWordProducerASL {
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 4) {
System.err.println(
"""
@@ -269,7 +269,7 @@ object KinesisWordProducerASL {
*/
private[streaming] object StreamingExamples extends Logging {
// Set reasonable logging levels for streaming if the user has not configured log4j.
- def setStreamingLogLevels() {
+ def setStreamingLogLevels(): Unit = {
val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements
if (!log4jInitialized) {
// We first log something to initialize Spark's default logging, then we override the
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
index 5fb83b26f8382..11e949536f2b6 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
@@ -68,7 +68,7 @@ private[kinesis] class KinesisCheckpointer(
if (checkpointer != null) {
try {
// We must call `checkpoint()` with no parameter to finish reading shards.
- // See an URL below for details:
+ // See a URL below for details:
// https://forums.aws.amazon.com/thread.jspa?threadID=244218
KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
} catch {
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
index 608da0b8bf563..8c3931a1c87fd 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
@@ -19,7 +19,9 @@ package org.apache.spark.streaming.kinesis
import scala.reflect.ClassTag
-import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import collection.JavaConverters._
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration}
+import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel
import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.rdd.RDD
@@ -43,7 +45,9 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
val messageHandler: Record => T,
val kinesisCreds: SparkAWSCredentials,
val dynamoDBCreds: Option[SparkAWSCredentials],
- val cloudWatchCreds: Option[SparkAWSCredentials]
+ val cloudWatchCreds: Option[SparkAWSCredentials],
+ val metricsLevel: MetricsLevel,
+ val metricsEnabledDimensions: Set[String]
) extends ReceiverInputDStream[T](_ssc) {
import KinesisReadConfigurations._
@@ -79,7 +83,8 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
override def getReceiver(): Receiver[T] = {
new KinesisReceiver(streamName, endpointUrl, regionName, initialPosition,
checkpointAppName, checkpointInterval, _storageLevel, messageHandler,
- kinesisCreds, dynamoDBCreds, cloudWatchCreds)
+ kinesisCreds, dynamoDBCreds, cloudWatchCreds,
+ metricsLevel, metricsEnabledDimensions)
}
}
@@ -104,6 +109,8 @@ object KinesisInputDStream {
private var kinesisCredsProvider: Option[SparkAWSCredentials] = None
private var dynamoDBCredsProvider: Option[SparkAWSCredentials] = None
private var cloudWatchCredsProvider: Option[SparkAWSCredentials] = None
+ private var metricsLevel: Option[MetricsLevel] = None
+ private var metricsEnabledDimensions: Option[Set[String]] = None
/**
* Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a
@@ -237,6 +244,7 @@ object KinesisInputDStream {
* endpoint. Defaults to [[DefaultCredentialsProvider]] if no custom value is specified.
*
* @param credentials [[SparkAWSCredentials]] to use for Kinesis authentication
+ * @return Reference to this [[KinesisInputDStream.Builder]]
*/
def kinesisCredentials(credentials: SparkAWSCredentials): Builder = {
kinesisCredsProvider = Option(credentials)
@@ -248,6 +256,7 @@ object KinesisInputDStream {
* endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set.
*
* @param credentials [[SparkAWSCredentials]] to use for DynamoDB authentication
+ * @return Reference to this [[KinesisInputDStream.Builder]]
*/
def dynamoDBCredentials(credentials: SparkAWSCredentials): Builder = {
dynamoDBCredsProvider = Option(credentials)
@@ -259,12 +268,43 @@ object KinesisInputDStream {
* endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set.
*
* @param credentials [[SparkAWSCredentials]] to use for CloudWatch authentication
+ * @return Reference to this [[KinesisInputDStream.Builder]]
*/
def cloudWatchCredentials(credentials: SparkAWSCredentials): Builder = {
cloudWatchCredsProvider = Option(credentials)
this
}
+ /**
+ * Sets the CloudWatch metrics level. Defaults to
+ * [[KinesisClientLibConfiguration.DEFAULT_METRICS_LEVEL]] if no custom value is specified.
+ *
+ * @param metricsLevel [[MetricsLevel]] to specify the CloudWatch metrics level
+ * @return Reference to this [[KinesisInputDStream.Builder]]
+ * @see
+ * [[https://docs.aws.amazon.com/streams/latest/dev/monitoring-with-kcl.html#metric-levels]]
+ */
+ def metricsLevel(metricsLevel: MetricsLevel): Builder = {
+ this.metricsLevel = Option(metricsLevel)
+ this
+ }
+
+ /**
+ * Sets the enabled CloudWatch metrics dimensions. Defaults to
+ * [[KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS]]
+ * if no custom value is specified.
+ *
+ * @param metricsEnabledDimensions Set[String] to specify which CloudWatch metrics dimensions
+ * should be enabled
+ * @return Reference to this [[KinesisInputDStream.Builder]]
+ * @see
+ * [[https://docs.aws.amazon.com/streams/latest/dev/monitoring-with-kcl.html#metric-levels]]
+ */
+ def metricsEnabledDimensions(metricsEnabledDimensions: Set[String]): Builder = {
+ this.metricsEnabledDimensions = Option(metricsEnabledDimensions)
+ this
+ }
+
/**
* Create a new instance of [[KinesisInputDStream]] with configured parameters and the provided
* message handler.
@@ -287,7 +327,9 @@ object KinesisInputDStream {
ssc.sc.clean(handler),
kinesisCredsProvider.getOrElse(DefaultCredentials),
dynamoDBCredsProvider,
- cloudWatchCredsProvider)
+ cloudWatchCredsProvider,
+ metricsLevel.getOrElse(DEFAULT_METRICS_LEVEL),
+ metricsEnabledDimensions.getOrElse(DEFAULT_METRICS_ENABLED_DIMENSIONS))
}
/**
@@ -324,4 +366,8 @@ object KinesisInputDStream {
private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1"
private[kinesis] val DEFAULT_INITIAL_POSITION: KinesisInitialPosition = new Latest()
private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2
+ private[kinesis] val DEFAULT_METRICS_LEVEL: MetricsLevel =
+ KinesisClientLibConfiguration.DEFAULT_METRICS_LEVEL
+ private[kinesis] val DEFAULT_METRICS_ENABLED_DIMENSIONS: Set[String] =
+ KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS.asScala.toSet
}
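
A hedged usage sketch of the two new builder options (stream name, application name, endpoint, region and the dimension set are illustrative; `ssc` is an existing StreamingContext):

import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel
import org.apache.spark.streaming.kinesis.{KinesisInitialPositions, KinesisInputDStream}

val stream = KinesisInputDStream.builder
  .streamingContext(ssc)
  .streamName("example-stream")
  .checkpointAppName("example-app")
  .endpointUrl("https://kinesis.us-east-1.amazonaws.com")
  .regionName("us-east-1")
  .initialPosition(new KinesisInitialPositions.Latest())
  // New in this patch: when these setters are not called, the defaults are
  // KinesisClientLibConfiguration.DEFAULT_METRICS_LEVEL and DEFAULT_METRICS_ENABLED_DIMENSIONS.
  .metricsLevel(MetricsLevel.DETAILED)
  .metricsEnabledDimensions(Set("Operation", "ShardId"))
  .build()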
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 69c52365b1bf8..6feb8f1b5598f 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -25,6 +25,7 @@ import scala.util.control.NonFatal
import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{KinesisClientLibConfiguration, Worker}
+import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel
import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.internal.Logging
@@ -92,7 +93,9 @@ private[kinesis] class KinesisReceiver[T](
messageHandler: Record => T,
kinesisCreds: SparkAWSCredentials,
dynamoDBCreds: Option[SparkAWSCredentials],
- cloudWatchCreds: Option[SparkAWSCredentials])
+ cloudWatchCreds: Option[SparkAWSCredentials],
+ metricsLevel: MetricsLevel,
+ metricsEnabledDimensions: Set[String])
extends Receiver[T](storageLevel) with Logging { receiver =>
/*
@@ -143,7 +146,7 @@ private[kinesis] class KinesisReceiver[T](
* This is called when the KinesisReceiver starts and must be non-blocking.
* The KCL creates and manages the receiving/processing thread pool through Worker.run().
*/
- override def onStart() {
+ override def onStart(): Unit = {
blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler)
workerId = Utils.localHostName() + ":" + UUID.randomUUID()
@@ -162,6 +165,8 @@ private[kinesis] class KinesisReceiver[T](
.withKinesisEndpoint(endpointUrl)
.withTaskBackoffTimeMillis(500)
.withRegionName(regionName)
+ .withMetricsLevel(metricsLevel)
+ .withMetricsEnabledDimensions(metricsEnabledDimensions.asJava)
// Update the Kinesis client lib config with timestamp
// if InitialPositionInStream.AT_TIMESTAMP is passed
@@ -211,7 +216,7 @@ private[kinesis] class KinesisReceiver[T](
* The KCL worker.shutdown() method stops the receiving/processing threads.
* The KCL will do its best to drain and checkpoint any in-flight records upon shutdown.
*/
- override def onStop() {
+ override def onStop(): Unit = {
if (workerThread != null) {
if (worker != null) {
worker.shutdown()
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index 8c6a399dd763e..b35573e92e168 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -51,7 +51,7 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w
*
* @param shardId assigned by the KCL to this particular RecordProcessor.
*/
- override def initialize(shardId: String) {
+ override def initialize(shardId: String): Unit = {
this.shardId = shardId
logInfo(s"Initialized workerId $workerId with shardId $shardId")
}
@@ -65,7 +65,8 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w
* @param checkpointer used to update Kinesis when this batch has been processed/stored
* in the DStream
*/
- override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) {
+ override def processRecords(batch: List[Record],
+ checkpointer: IRecordProcessorCheckpointer): Unit = {
if (!receiver.isStopped()) {
try {
// Limit the number of processed records from Kinesis stream. This is because the KCL cannot
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
deleted file mode 100644
index c60b9896a3473..0000000000000
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
+++ /dev/null
@@ -1,632 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.streaming.kinesis
-
-import scala.reflect.ClassTag
-
-import com.amazonaws.regions.RegionUtils
-import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
-import com.amazonaws.services.kinesis.model.Record
-
-import org.apache.spark.api.java.function.{Function => JFunction}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Duration, StreamingContext}
-import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext}
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
-
-object KinesisUtils {
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param ssc StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- * @param messageHandler A custom message handler that can generate a generic output from a
- * Kinesis `Record`, which contains both message data, and metadata.
- *
- * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
- * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
- * gets the AWS credentials.
- */
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream[T: ClassTag](
- ssc: StreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel,
- messageHandler: Record => T): ReceiverInputDStream[T] = {
- val cleanedHandler = ssc.sc.clean(messageHandler)
- // Setting scope to override receiver stream's scope of "receiver stream"
- ssc.withNamedScope("kinesis stream") {
- new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
- KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream),
- kinesisAppName, checkpointInterval, storageLevel,
- cleanedHandler, DefaultCredentials, None, None)
- }
- }
-
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param ssc StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- * @param messageHandler A custom message handler that can generate a generic output from a
- * Kinesis `Record`, which contains both message data, and metadata.
- * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
- * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
- *
- * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
- * is enabled. Make sure that your checkpoint directory is secure.
- */
- // scalastyle:off
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream[T: ClassTag](
- ssc: StreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel,
- messageHandler: Record => T,
- awsAccessKeyId: String,
- awsSecretKey: String): ReceiverInputDStream[T] = {
- // scalastyle:on
- val cleanedHandler = ssc.sc.clean(messageHandler)
- ssc.withNamedScope("kinesis stream") {
- val kinesisCredsProvider = BasicCredentials(
- awsAccessKeyId = awsAccessKeyId,
- awsSecretKey = awsSecretKey)
- new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
- KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream),
- kinesisAppName, checkpointInterval, storageLevel,
- cleanedHandler, kinesisCredsProvider, None, None)
- }
- }
-
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param ssc StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- * @param messageHandler A custom message handler that can generate a generic output from a
- * Kinesis `Record`, which contains both message data, and metadata.
- * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
- * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
- * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from
- * Kinesis stream.
- * @param stsSessionName Name to uniquely identify STS sessions if multiple principals assume
- * the same role.
- * @param stsExternalId External ID that can be used to validate against the assumed IAM role's
- * trust policy.
- *
- * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
- * is enabled. Make sure that your checkpoint directory is secure.
- */
- // scalastyle:off
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream[T: ClassTag](
- ssc: StreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel,
- messageHandler: Record => T,
- awsAccessKeyId: String,
- awsSecretKey: String,
- stsAssumeRoleArn: String,
- stsSessionName: String,
- stsExternalId: String): ReceiverInputDStream[T] = {
- // scalastyle:on
- val cleanedHandler = ssc.sc.clean(messageHandler)
- ssc.withNamedScope("kinesis stream") {
- val kinesisCredsProvider = STSCredentials(
- stsRoleArn = stsAssumeRoleArn,
- stsSessionName = stsSessionName,
- stsExternalId = Option(stsExternalId),
- longLivedCreds = BasicCredentials(
- awsAccessKeyId = awsAccessKeyId,
- awsSecretKey = awsSecretKey))
- new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
- KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream),
- kinesisAppName, checkpointInterval, storageLevel,
- cleanedHandler, kinesisCredsProvider, None, None)
- }
- }
-
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param ssc StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- *
- * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
- * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
- * gets the AWS credentials.
- */
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream(
- ssc: StreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = {
- // Setting scope to override receiver stream's scope of "receiver stream"
- ssc.withNamedScope("kinesis stream") {
- new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
- KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream),
- kinesisAppName, checkpointInterval, storageLevel,
- KinesisInputDStream.defaultMessageHandler, DefaultCredentials, None, None)
- }
- }
-
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param ssc StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
- * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
- *
- * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
- * is enabled. Make sure that your checkpoint directory is secure.
- */
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream(
- ssc: StreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel,
- awsAccessKeyId: String,
- awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = {
- ssc.withNamedScope("kinesis stream") {
- val kinesisCredsProvider = BasicCredentials(
- awsAccessKeyId = awsAccessKeyId,
- awsSecretKey = awsSecretKey)
- new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
- KinesisInitialPositions.fromKinesisInitialPosition(initialPositionInStream),
- kinesisAppName, checkpointInterval, storageLevel,
- KinesisInputDStream.defaultMessageHandler, kinesisCredsProvider, None, None)
- }
- }
-
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param jssc Java StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- * @param messageHandler A custom message handler that can generate a generic output from a
- * Kinesis `Record`, which contains both message data, and metadata.
- * @param recordClass Class of the records in DStream
- *
- * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
- * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
- * gets the AWS credentials.
- */
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream[T](
- jssc: JavaStreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel,
- messageHandler: JFunction[Record, T],
- recordClass: Class[T]): JavaReceiverInputDStream[T] = {
- implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
- val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
- createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
- initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler)
- }
-
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param jssc Java StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- * @param messageHandler A custom message handler that can generate a generic output from a
- * Kinesis `Record`, which contains both message data, and metadata.
- * @param recordClass Class of the records in DStream
- * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
- * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
- *
- * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
- * is enabled. Make sure that your checkpoint directory is secure.
- */
- // scalastyle:off
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream[T](
- jssc: JavaStreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel,
- messageHandler: JFunction[Record, T],
- recordClass: Class[T],
- awsAccessKeyId: String,
- awsSecretKey: String): JavaReceiverInputDStream[T] = {
- // scalastyle:on
- implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
- val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
- createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
- initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler,
- awsAccessKeyId, awsSecretKey)
- }
-
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param jssc Java StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- * @param messageHandler A custom message handler that can generate a generic output from a
- * Kinesis `Record`, which contains both message data, and metadata.
- * @param recordClass Class of the records in DStream
- * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
- * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
- * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from
- * Kinesis stream.
- * @param stsSessionName Name to uniquely identify STS sessions if multiple princpals assume
- * the same role.
- * @param stsExternalId External ID that can be used to validate against the assumed IAM role's
- * trust policy.
- *
- * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
- * is enabled. Make sure that your checkpoint directory is secure.
- */
- // scalastyle:off
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream[T](
- jssc: JavaStreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel,
- messageHandler: JFunction[Record, T],
- recordClass: Class[T],
- awsAccessKeyId: String,
- awsSecretKey: String,
- stsAssumeRoleArn: String,
- stsSessionName: String,
- stsExternalId: String): JavaReceiverInputDStream[T] = {
- // scalastyle:on
- implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
- val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
- createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
- initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler,
- awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId)
- }
-
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param jssc Java StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- *
- * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
- * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
- * gets the AWS credentials.
- */
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream(
- jssc: JavaStreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel
- ): JavaReceiverInputDStream[Array[Byte]] = {
- createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
- initialPositionInStream, checkpointInterval, storageLevel,
- KinesisInputDStream.defaultMessageHandler(_))
- }
-
- /**
- * Create an input stream that pulls messages from a Kinesis stream.
- * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
- *
- * @param jssc Java StreamingContext object
- * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
- * (KCL) to update DynamoDB
- * @param streamName Kinesis stream name
- * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
- * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
- * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
- * worker's initial starting position in the stream.
- * The values are either the beginning of the stream
- * per Kinesis' limit of 24 hours
- * (InitialPositionInStream.TRIM_HORIZON) or
- * the tip of the stream (InitialPositionInStream.LATEST).
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
- * @param storageLevel Storage level to use for storing the received objects.
- * StorageLevel.MEMORY_AND_DISK_2 is recommended.
- * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
- * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
- *
- * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
- * is enabled. Make sure that your checkpoint directory is secure.
- */
- @deprecated("Use KinesisInputDStream.builder instead", "2.2.0")
- def createStream(
- jssc: JavaStreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: InitialPositionInStream,
- checkpointInterval: Duration,
- storageLevel: StorageLevel,
- awsAccessKeyId: String,
- awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = {
- createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
- initialPositionInStream, checkpointInterval, storageLevel,
- KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey)
- }
-
- private def validateRegion(regionName: String): String = {
- Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse {
- throw new IllegalArgumentException(s"Region name '$regionName' is not valid")
- }
- }
-}
-
-/**
- * This is a helper class that wraps the methods in KinesisUtils into more Python-friendly class and
- * function so that it can be easily instantiated and called from Python's KinesisUtils.
- */
-private class KinesisUtilsPythonHelper {
-
- def getInitialPositionInStream(initialPositionInStream: Int): InitialPositionInStream = {
- initialPositionInStream match {
- case 0 => InitialPositionInStream.LATEST
- case 1 => InitialPositionInStream.TRIM_HORIZON
- case _ => throw new IllegalArgumentException(
- "Illegal InitialPositionInStream. Please use " +
- "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON")
- }
- }
-
- // scalastyle:off
- def createStream(
- jssc: JavaStreamingContext,
- kinesisAppName: String,
- streamName: String,
- endpointUrl: String,
- regionName: String,
- initialPositionInStream: Int,
- checkpointInterval: Duration,
- storageLevel: StorageLevel,
- awsAccessKeyId: String,
- awsSecretKey: String,
- stsAssumeRoleArn: String,
- stsSessionName: String,
- stsExternalId: String): JavaReceiverInputDStream[Array[Byte]] = {
- // scalastyle:on
- if (!(stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null)
- && !(stsAssumeRoleArn == null && stsSessionName == null && stsExternalId == null)) {
- throw new IllegalArgumentException("stsAssumeRoleArn, stsSessionName, and stsExtenalId " +
- "must all be defined or all be null")
- }
-
- if (stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) {
- validateAwsCreds(awsAccessKeyId, awsSecretKey)
- KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
- getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel,
- KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey,
- stsAssumeRoleArn, stsSessionName, stsExternalId)
- } else {
- validateAwsCreds(awsAccessKeyId, awsSecretKey)
- if (awsAccessKeyId == null && awsSecretKey == null) {
- KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
- getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel)
- } else {
- KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
- getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel,
- awsAccessKeyId, awsSecretKey)
- }
- }
- }
-
- // Throw IllegalArgumentException unless both values are null or neither are.
- private def validateAwsCreds(awsAccessKeyId: String, awsSecretKey: String) {
- if (awsAccessKeyId == null && awsSecretKey != null) {
- throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null")
- }
- if (awsAccessKeyId != null && awsSecretKey == null) {
- throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null")
- }
- }
-}
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala
new file mode 100644
index 0000000000000..c89dedd3366d1
--- /dev/null
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.kinesis
+
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.Duration
+import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext}
+
+/**
+ * This is a helper class that wraps the methods in KinesisUtils into a more Python-friendly class
+ * and function so that it can be easily instantiated and called from Python's KinesisUtils.
+ */
+private class KinesisUtilsPythonHelper {
+
+ // scalastyle:off
+ def createStream(
+ jssc: JavaStreamingContext,
+ kinesisAppName: String,
+ streamName: String,
+ endpointUrl: String,
+ regionName: String,
+ initialPositionInStream: Int,
+ checkpointInterval: Duration,
+ storageLevel: StorageLevel,
+ awsAccessKeyId: String,
+ awsSecretKey: String,
+ stsAssumeRoleArn: String,
+ stsSessionName: String,
+ stsExternalId: String): JavaReceiverInputDStream[Array[Byte]] = {
+ // scalastyle:on
+ if (!(stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null)
+ && !(stsAssumeRoleArn == null && stsSessionName == null && stsExternalId == null)) {
+ throw new IllegalArgumentException("stsAssumeRoleArn, stsSessionName, and stsExternalId " +
+ "must all be defined or all be null")
+ }
+ if (awsAccessKeyId == null && awsSecretKey != null) {
+ throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null")
+ }
+ if (awsAccessKeyId != null && awsSecretKey == null) {
+ throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null")
+ }
+
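+ // Map the integer passed in from Python to the corresponding KCL initial position:
+ // 0 -> LATEST, 1 -> TRIM_HORIZON; any other value is rejected.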
+ val kinesisInitialPosition = initialPositionInStream match {
+ case 0 => InitialPositionInStream.LATEST
+ case 1 => InitialPositionInStream.TRIM_HORIZON
+ case _ => throw new IllegalArgumentException(
+ "Illegal InitialPositionInStream. Please use " +
+ "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON")
+ }
+
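+ // Configure the common, non-credential settings through the KinesisInputDStream builder;
+ // credentials are attached below depending on which optional arguments were supplied.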
+ val builder = KinesisInputDStream.builder.
+ streamingContext(jssc).
+ checkpointAppName(kinesisAppName).
+ streamName(streamName).
+ endpointUrl(endpointUrl).
+ regionName(regionName).
+ initialPosition(KinesisInitialPositions.fromKinesisInitialPosition(kinesisInitialPosition)).
+ checkpointInterval(checkpointInterval).
+ storageLevel(storageLevel)
+
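+ // When all three STS arguments are given, wrap the basic key-pair credentials in an STS
+ // provider that assumes the specified IAM role for reading from Kinesis.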
+ if (stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) {
+ val kinesisCredsProvider = STSCredentials(
+ stsAssumeRoleArn, stsSessionName, Option(stsExternalId),
+ BasicCredentials(awsAccessKeyId, awsSecretKey))
+ builder.
+ kinesisCredentials(kinesisCredsProvider).
+ buildWithMessageHandler(KinesisInputDStream.defaultMessageHandler)
+ } else {
+ if (awsAccessKeyId == null && awsSecretKey == null) {
+ builder.build()
+ } else {
+ builder.kinesisCredentials(BasicCredentials(awsAccessKeyId, awsSecretKey)).build()
+ }
+ }
+ }
+
+}
diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
deleted file mode 100644
index b37b087467926..0000000000000
--- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.kinesis;
-
-import com.amazonaws.services.kinesis.model.Record;
-import org.junit.Test;
-
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.storage.StorageLevel;
-import org.apache.spark.streaming.Duration;
-import org.apache.spark.streaming.LocalJavaStreamingContext;
-import org.apache.spark.streaming.api.java.JavaDStream;
-
-import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
-
-/**
- * Demonstrate the use of the KinesisUtils Java API
- */
-public class JavaKinesisStreamSuite extends LocalJavaStreamingContext {
- @Test
- public void testKinesisStream() {
- String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl();
- String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl);
-
- // Tests the API, does not actually test data receiving
- JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream",
- dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000),
- StorageLevel.MEMORY_AND_DISK_2());
- ssc.stop();
- }
-
- @Test
- public void testAwsCreds() {
- String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl();
- String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl);
-
- // Tests the API, does not actually test data receiving
- JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream",
- dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000),
- StorageLevel.MEMORY_AND_DISK_2(), "fakeAccessKey", "fakeSecretKey");
- ssc.stop();
- }
-
- private static Function<Record, String> handler = new Function<Record, String>() {
- @Override
- public String call(Record record) {
- return record.getPartitionKey() + "-" + record.getSequenceNumber();
- }
- };
-
- @Test
- public void testCustomHandler() {
- // Tests the API, does not actually test data receiving
- JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST,
- new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class);
-
- ssc.stop();
- }
-
- @Test
- public void testCustomHandlerAwsCreds() {
- // Tests the API, does not actually test data receiving
- JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST,
- new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class,
- "fakeAccessKey", "fakeSecretKey");
-
- ssc.stop();
- }
-
- @Test
- public void testCustomHandlerAwsStsCreds() {
- // Tests the API, does not actually test data receiving
- JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST,
- new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class,
- "fakeAccessKey", "fakeSecretKey", "fakeSTSRoleArn", "fakeSTSSessionName",
- "fakeSTSExternalId");
-
- ssc.stop();
- }
-}
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
index ac0e6a8429d06..3e88e956ec237 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
@@ -28,7 +28,7 @@ import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.scalatest.concurrent.Eventually
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.streaming.{Duration, TestSuiteBase}
import org.apache.spark.util.ManualClock
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
index 1c81298a7c201..8dc4de1aa3609 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
@@ -27,7 +27,7 @@ trait KinesisFunSuite extends SparkFunSuite {
import KinesisTestUtils._
/** Run the test if environment variable is set or ignore the test */
- def testIfEnabled(testName: String)(testBody: => Unit) {
+ def testIfEnabled(testName: String)(testBody: => Unit): Unit = {
if (shouldRunTests) {
test(testName)(testBody)
} else {
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala
index 361520e292266..8b0d73c96da73 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala
@@ -19,9 +19,11 @@ package org.apache.spark.streaming.kinesis
import java.util.Calendar
-import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import collection.JavaConverters._
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration}
+import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel
import org.scalatest.BeforeAndAfterEach
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Duration, Seconds, StreamingContext, TestSuiteBase}
@@ -82,6 +84,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE
assert(dstream.kinesisCreds == DefaultCredentials)
assert(dstream.dynamoDBCreds == None)
assert(dstream.cloudWatchCreds == None)
+ assert(dstream.metricsLevel == DEFAULT_METRICS_LEVEL)
+ assert(dstream.metricsEnabledDimensions == DEFAULT_METRICS_ENABLED_DIMENSIONS)
}
test("should propagate custom non-auth values to KinesisInputDStream") {
@@ -94,6 +98,9 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE
val customKinesisCreds = mock[SparkAWSCredentials]
val customDynamoDBCreds = mock[SparkAWSCredentials]
val customCloudWatchCreds = mock[SparkAWSCredentials]
+ val customMetricsLevel = MetricsLevel.NONE
+ val customMetricsEnabledDimensions =
+ KinesisClientLibConfiguration.METRICS_ALWAYS_ENABLED_DIMENSIONS.asScala.toSet
val dstream = builder
.endpointUrl(customEndpointUrl)
@@ -105,6 +112,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE
.kinesisCredentials(customKinesisCreds)
.dynamoDBCredentials(customDynamoDBCreds)
.cloudWatchCredentials(customCloudWatchCreds)
+ .metricsLevel(customMetricsLevel)
+ .metricsEnabledDimensions(customMetricsEnabledDimensions)
.build()
assert(dstream.endpointUrl == customEndpointUrl)
assert(dstream.regionName == customRegion)
@@ -115,6 +124,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE
assert(dstream.kinesisCreds == customKinesisCreds)
assert(dstream.dynamoDBCreds == Option(customDynamoDBCreds))
assert(dstream.cloudWatchCreds == Option(customCloudWatchCreds))
+ assert(dstream.metricsLevel == customMetricsLevel)
+ assert(dstream.metricsEnabledDimensions == customMetricsEnabledDimensions)
// Testing with AtTimestamp
val cal = Calendar.getInstance()
@@ -132,6 +143,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE
.kinesisCredentials(customKinesisCreds)
.dynamoDBCredentials(customDynamoDBCreds)
.cloudWatchCredentials(customCloudWatchCreds)
+ .metricsLevel(customMetricsLevel)
+ .metricsEnabledDimensions(customMetricsEnabledDimensions)
.build()
assert(dstreamAtTimestamp.endpointUrl == customEndpointUrl)
assert(dstreamAtTimestamp.regionName == customRegion)
@@ -145,6 +158,8 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE
assert(dstreamAtTimestamp.kinesisCreds == customKinesisCreds)
assert(dstreamAtTimestamp.dynamoDBCreds == Option(customDynamoDBCreds))
assert(dstreamAtTimestamp.cloudWatchCreds == Option(customCloudWatchCreds))
+ assert(dstreamAtTimestamp.metricsLevel == customMetricsLevel)
+ assert(dstreamAtTimestamp.metricsEnabledDimensions == customMetricsEnabledDimensions)
}
test("old Api should throw UnsupportedOperationExceptionexception with AT_TIMESTAMP") {
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 52690847418ef..470a8cecc8fd9 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -27,7 +27,7 @@ import com.amazonaws.services.kinesis.model.Record
import org.mockito.ArgumentMatchers.{anyList, anyString, eq => meq}
import org.mockito.Mockito.{never, times, verify, when}
import org.scalatest.{BeforeAndAfter, Matchers}
-import org.scalatest.mockito.MockitoSugar
+import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.streaming.{Duration, TestSuiteBase}
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 51ee7fd213de5..eee62d25e62bb 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -21,7 +21,6 @@ import scala.collection.mutable
import scala.concurrent.duration._
import scala.util.Random
-import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.Matchers._
@@ -31,7 +30,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
-import org.apache.spark.streaming._
+import org.apache.spark.streaming.{LocalStreamingContext, _}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest
import org.apache.spark.streaming.kinesis.KinesisReadConfigurations._
@@ -41,7 +40,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
import org.apache.spark.util.Utils
abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite
- with Eventually with BeforeAndAfter with BeforeAndAfterAll {
+ with LocalStreamingContext with Eventually with BeforeAndAfter with BeforeAndAfterAll {
// This is the name that KCL will use to save metadata to DynamoDB
private val appName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
@@ -54,15 +53,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
private val dummyAWSSecretKey = "dummySecretKey"
private var testUtils: KinesisTestUtils = null
- private var ssc: StreamingContext = null
private var sc: SparkContext = null
override def beforeAll(): Unit = {
- val conf = new SparkConf()
- .setMaster("local[4]")
- .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name
- sc = new SparkContext(conf)
-
runIfTestsEnabled("Prepare KinesisTestUtils") {
testUtils = new KPLBasedKinesisTestUtils()
testUtils.createStream()
@@ -71,12 +64,6 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
override def afterAll(): Unit = {
try {
- if (ssc != null) {
- ssc.stop()
- }
- if (sc != null) {
- sc.stop()
- }
if (testUtils != null) {
// Delete the Kinesis stream as well as the DynamoDB table generated by
// Kinesis Client Library when consuming the stream
@@ -88,34 +75,36 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
}
}
- before {
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ val conf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name
+ sc = new SparkContext(conf)
ssc = new StreamingContext(sc, batchDuration)
}
- after {
- if (ssc != null) {
- ssc.stop(stopSparkContext = false)
- ssc = null
- }
- if (testUtils != null) {
- testUtils.deleteDynamoDBTable(appName)
+ override def afterEach(): Unit = {
+ try {
+ if (testUtils != null) {
+ testUtils.deleteDynamoDBTable(appName)
+ }
+ } finally {
+ super.afterEach()
}
}
- test("KinesisUtils API") {
- val kinesisStream1 = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream",
- dummyEndpointUrl, dummyRegionName,
- InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2)
- val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream",
- dummyEndpointUrl, dummyRegionName,
- InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2,
- dummyAWSAccessKey, dummyAWSSecretKey)
- }
-
test("RDD generation") {
- val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream",
- dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2),
- StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey)
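+ // Build the test input stream through the KinesisInputDStream builder, which replaces the
+ // removed KinesisUtils.createStream API.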
+ val inputStream = KinesisInputDStream.builder.
+ streamingContext(ssc).
+ checkpointAppName(appName).
+ streamName("dummyStream").
+ endpointUrl(dummyEndpointUrl).
+ regionName(dummyRegionName).initialPosition(new Latest()).
+ checkpointInterval(Seconds(2)).
+ storageLevel(StorageLevel.MEMORY_AND_DISK_2).
+ kinesisCredentials(BasicCredentials(dummyAWSAccessKey, dummyAWSSecretKey)).
+ build()
assert(inputStream.isInstanceOf[KinesisInputDStream[Array[Byte]]])
val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream[Array[Byte]]]
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala b/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala
index ecc37dcaad1fe..d733868908350 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala
@@ -81,13 +81,13 @@ object Edge {
override def copyElement(
src: Array[Edge[ED]], srcPos: Int,
- dst: Array[Edge[ED]], dstPos: Int) {
+ dst: Array[Edge[ED]], dstPos: Int): Unit = {
dst(dstPos) = src(srcPos)
}
override def copyRange(
src: Array[Edge[ED]], srcPos: Int,
- dst: Array[Edge[ED]], dstPos: Int, length: Int) {
+ dst: Array[Edge[ED]], dstPos: Int, length: Int): Unit = {
System.arraycopy(src, srcPos, dst, dstPos, length)
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
index ef0b943fc3c38..4ff5b02daecbe 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
@@ -30,7 +30,7 @@ object GraphXUtils {
/**
* Registers classes that GraphX uses with Kryo.
*/
- def registerKryoClasses(conf: SparkConf) {
+ def registerKryoClasses(conf: SparkConf): Unit = {
conf.registerKryoClasses(Array(
classOf[Edge[Object]],
classOf[(VertexId, Object)],
@@ -54,7 +54,7 @@ object GraphXUtils {
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
reduceFunc: (A, A) => A,
activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = {
- def sendMsg(ctx: EdgeContext[VD, ED, A]) {
+ def sendMsg(ctx: EdgeContext[VD, ED, A]): Unit = {
mapFunc(ctx.toEdgeTriplet).foreach { kv =>
val id = kv._1
val msg = kv._2
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
index 0e6a340a680ba..8d03112a1c3dc 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
@@ -222,7 +222,7 @@ class EdgePartition[
*
* @param f an external state mutating user defined function.
*/
- def foreach(f: Edge[ED] => Unit) {
+ def foreach(f: Edge[ED] => Unit): Unit = {
iterator.foreach(f)
}
@@ -495,7 +495,7 @@ private class AggregatingEdgeContext[VD, ED, A](
srcId: VertexId, dstId: VertexId,
localSrcId: Int, localDstId: Int,
srcAttr: VD, dstAttr: VD,
- attr: ED) {
+ attr: ED): Unit = {
_srcId = srcId
_dstId = dstId
_localSrcId = localSrcId
@@ -505,13 +505,13 @@ private class AggregatingEdgeContext[VD, ED, A](
_attr = attr
}
- def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD) {
+ def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD): Unit = {
_srcId = srcId
_localSrcId = localSrcId
_srcAttr = srcAttr
}
- def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) {
+ def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED): Unit = {
_dstId = dstId
_localDstId = localDstId
_dstAttr = dstAttr
@@ -524,14 +524,14 @@ private class AggregatingEdgeContext[VD, ED, A](
override def dstAttr: VD = _dstAttr
override def attr: ED = _attr
- override def sendToSrc(msg: A) {
+ override def sendToSrc(msg: A): Unit = {
send(_localSrcId, msg)
}
- override def sendToDst(msg: A) {
+ override def sendToDst(msg: A): Unit = {
send(_localDstId, msg)
}
- @inline private def send(localId: Int, msg: A) {
+ @inline private def send(localId: Int, msg: A): Unit = {
if (bitset.get(localId)) {
aggregates(localId) = mergeMsg(aggregates(localId), msg)
} else {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
index 27c08c894a39f..c7868f85d1f76 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
@@ -30,7 +30,7 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
private[this] val edges = new PrimitiveVector[Edge[ED]](size)
/** Add a new edge to the partition. */
- def add(src: VertexId, dst: VertexId, d: ED) {
+ def add(src: VertexId, dst: VertexId, d: ED): Unit = {
edges += Edge(src, dst, d)
}
@@ -90,7 +90,7 @@ class ExistingEdgePartitionBuilder[
private[this] val edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size)
/** Add a new edge to the partition. */
- def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) {
+ def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED): Unit = {
edges += EdgeWithLocalIds(src, dst, localSrc, localDst, d)
}
@@ -153,13 +153,13 @@ private[impl] object EdgeWithLocalIds {
override def copyElement(
src: Array[EdgeWithLocalIds[ED]], srcPos: Int,
- dst: Array[EdgeWithLocalIds[ED]], dstPos: Int) {
+ dst: Array[EdgeWithLocalIds[ED]], dstPos: Int): Unit = {
dst(dstPos) = src(srcPos)
}
override def copyRange(
src: Array[EdgeWithLocalIds[ED]], srcPos: Int,
- dst: Array[EdgeWithLocalIds[ED]], dstPos: Int, length: Int) {
+ dst: Array[EdgeWithLocalIds[ED]], dstPos: Int, length: Int): Unit = {
System.arraycopy(src, srcPos, dst, dstPos, length)
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 0a97ab492600d..8564597f4f135 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -103,15 +103,16 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
(part, (e.srcId, e.dstId, e.attr))
}
.partitionBy(new HashPartitioner(numPartitions))
- .mapPartitionsWithIndex( { (pid, iter) =>
- val builder = new EdgePartitionBuilder[ED, VD]()(edTag, vdTag)
- iter.foreach { message =>
- val data = message._2
- builder.add(data._1, data._2, data._3)
- }
- val edgePartition = builder.toEdgePartition
- Iterator((pid, edgePartition))
- }, preservesPartitioning = true)).cache()
+ .mapPartitionsWithIndex(
+ { (pid: Int, iter: Iterator[(PartitionID, (VertexId, VertexId, ED))]) =>
+ val builder = new EdgePartitionBuilder[ED, VD]()(edTag, vdTag)
+ iter.foreach { message =>
+ val data = message._2
+ builder.add(data._1, data._2, data._3)
+ }
+ val edgePartition = builder.toEdgePartition
+ Iterator((pid, edgePartition))
+ }, preservesPartitioning = true)).cache()
GraphImpl.fromExistingRDDs(vertices.withEdges(newEdges), newEdges)
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
index d2194d85bf525..e0d4dd3248734 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
@@ -58,7 +58,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag](
* `vertices`. This operation modifies the `ReplicatedVertexView`, and callers can access `edges`
* afterwards to obtain the upgraded view.
*/
- def upgrade(vertices: VertexRDD[VD], includeSrc: Boolean, includeDst: Boolean) {
+ def upgrade(vertices: VertexRDD[VD], includeSrc: Boolean, includeDst: Boolean): Unit = {
val shipSrc = includeSrc && !hasSrcId
val shipDst = includeDst && !hasDstId
if (shipSrc || shipDst) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index 6453bbeae9f10..bef380dc12c23 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -123,7 +123,7 @@ class RoutingTablePartition(
*/
def foreachWithinEdgePartition
(pid: PartitionID, includeSrc: Boolean, includeDst: Boolean)
- (f: VertexId => Unit) {
+ (f: VertexId => Unit): Unit = {
val (vidsCandidate, srcVids, dstVids) = routingTable(pid)
val size = vidsCandidate.length
if (includeSrc && includeDst) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index 2847a4e172d40..c508056fe3ae3 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -98,7 +98,7 @@ object SVDPlusPlus {
(ctx: EdgeContext[
(Array[Double], Array[Double], Double, Double),
Double,
- (Array[Double], Array[Double], Double)]) {
+ (Array[Double], Array[Double], Double)]): Unit = {
val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
val (p, q) = (usr._1, itm._1)
val rank = p.length
@@ -177,7 +177,7 @@ object SVDPlusPlus {
// calculate error on training set
def sendMsgTestF(conf: Conf, u: Double)
- (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]) {
+ (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]): Unit = {
val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
val (p, q) = (usr._1, itm._1)
var pred = u + usr._3 + itm._3 + blas.ddot(q.length, q, 1, usr._2, 1)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
index 2715137d19ebc..211b4d6e4c5d3 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
@@ -85,7 +85,7 @@ object TriangleCount {
}
// Edge function computes intersection of smaller vertex with larger vertex
- def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) {
+ def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]): Unit = {
val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) {
(ctx.srcAttr, ctx.dstAttr)
} else {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
index 5ece5ae5c359b..dc3cdc452a389 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
@@ -118,7 +118,7 @@ private[graphx] object BytecodeUtils {
if (name == methodName) {
new MethodVisitor(ASM7) {
override def visitMethodInsn(
- op: Int, owner: String, name: String, desc: String, itf: Boolean) {
+ op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = {
if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) {
if (!skipClass(owner)) {
methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name))
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala
index 972237da1cb28..e3b283649cb2c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala
@@ -71,7 +71,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
}
/** Set the value for a key */
- def update(k: K, v: V) {
+ def update(k: K, v: V): Unit = {
val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
_values(pos) = v
keySet.rehashIfNeeded(k, grow, move)
@@ -80,7 +80,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
/** Set the value for a key */
- def setMerge(k: K, v: V, mergeF: (V, V) => V) {
+ def setMerge(k: K, v: V, mergeF: (V, V) => V): Unit = {
val pos = keySet.addWithoutResize(k)
val ind = pos & OpenHashSet.POSITION_MASK
if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
index 84940d96b563f..32844104c1deb 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
@@ -26,8 +26,11 @@
import java.util.Map;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.logging.Level;
+import java.util.logging.Logger;
import static org.apache.spark.launcher.CommandBuilderUtils.*;
+import static org.apache.spark.launcher.CommandBuilderUtils.join;
/**
* Launcher for Spark applications.
@@ -38,6 +41,8 @@
*/
public class SparkLauncher extends AbstractLauncher<SparkLauncher> {
+ private static final Logger LOG = Logger.getLogger(SparkLauncher.class.getName());
+
/** The Spark master. */
public static final String SPARK_MASTER = "spark.master";
@@ -363,6 +368,9 @@ public SparkAppHandle startApplication(SparkAppHandle.Listener... listeners) thr
String loggerName = getLoggerName();
ProcessBuilder pb = createBuilder();
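+ // Log the full spark-submit command line at FINE level to help debug launcher issues.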
+ if (LOG.isLoggable(Level.FINE)) {
+ LOG.fine(String.format("Launching Spark application:%n%s", join(" ", pb.command())));
+ }
boolean outputToLog = outputStream == null;
boolean errorToLog = !redirectErrorStream && errorStream == null;
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index 3479e0c3422bd..383c3f60a595b 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -348,7 +348,7 @@ private List buildPySparkShellCommand(Map env) throws IO
}
private List<String> buildSparkRCommand(Map<String, String> env) throws IOException {
- if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) {
+ if (!appArgs.isEmpty() && (appArgs.get(0).endsWith(".R") || appArgs.get(0).endsWith(".r"))) {
System.err.println(
"Running R applications through 'sparkR' is not supported as of Spark 2.0.\n" +
"Use ./bin/spark-submit ");
@@ -390,9 +390,7 @@ boolean isClientMode(Map userProps) {
String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER));
String userDeployMode = firstNonEmpty(deployMode, userProps.get(SparkLauncher.DEPLOY_MODE));
// Default master is "local[*]", so assume client mode in that case
- return userMaster == null ||
- "client".equals(userDeployMode) ||
- (!userMaster.equals("yarn-cluster") && userDeployMode == null);
+ return userMaster == null || userDeployMode == null || "client".equals(userDeployMode);
}
/**
diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
index 32a91b1789412..752e8d4c23f8b 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
@@ -250,6 +250,26 @@ public void testMissingAppResource() {
new SparkSubmitCommandBuilder().buildSparkSubmitArgs();
}
+ @Test
+ public void testIsClientMode() {
+ // Default master is "local[*]"
+ SparkSubmitCommandBuilder builder = newCommandBuilder(Collections.emptyList());
+ assertTrue("By default application run in local mode",
+ builder.isClientMode(Collections.emptyMap()));
+ // --master yarn or it can be any RM
+ List<String> sparkSubmitArgs = Arrays.asList(parser.MASTER, "yarn");
+ builder = newCommandBuilder(sparkSubmitArgs);
+ assertTrue("By default deploy mode is client", builder.isClientMode(Collections.emptyMap()));
+ // --master yarn and set spark.submit.deployMode to client
+ Map<String, String> userProps = new HashMap<>();
+ userProps.put("spark.submit.deployMode", "client");
+ assertTrue(builder.isClientMode(userProps));
+ // --master mesos --deploy-mode cluster
+ sparkSubmitArgs = Arrays.asList(parser.MASTER, "mesos", parser.DEPLOY_MODE, "cluster");
+ builder = newCommandBuilder(sparkSubmitArgs);
+ assertFalse(builder.isClientMode(Collections.emptyMap()));
+ }
+
private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) throws Exception {
final String DRIVER_DEFAULT_PARAM = "-Ddriver-default";
final String DRIVER_EXTRA_PARAM = "-Ddriver-extra";
diff --git a/licenses-binary/LICENSE-JLargeArrays.txt b/licenses-binary/LICENSE-JLargeArrays.txt
new file mode 100644
index 0000000000000..304e724556984
--- /dev/null
+++ b/licenses-binary/LICENSE-JLargeArrays.txt
@@ -0,0 +1,23 @@
+JLargeArrays
+Copyright (C) 2013 onward University of Warsaw, ICM
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/licenses-binary/LICENSE-JTransforms.txt b/licenses-binary/LICENSE-JTransforms.txt
new file mode 100644
index 0000000000000..2f0589f76da7d
--- /dev/null
+++ b/licenses-binary/LICENSE-JTransforms.txt
@@ -0,0 +1,23 @@
+JTransforms
+Copyright (c) 2007 onward, Piotr Wendykier
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/licenses-binary/LICENSE-dnsjava.txt b/licenses-binary/LICENSE-dnsjava.txt
new file mode 100644
index 0000000000000..70ee5b12ff23f
--- /dev/null
+++ b/licenses-binary/LICENSE-dnsjava.txt
@@ -0,0 +1,24 @@
+Copyright (c) 1998-2011, Brian Wellington.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/licenses-binary/LICENSE-jtransforms.html b/licenses-binary/LICENSE-jtransforms.html
deleted file mode 100644
index 351c17412357b..0000000000000
--- a/licenses-binary/LICENSE-jtransforms.html
+++ /dev/null
@@ -1,388 +0,0 @@
-
-
-Mozilla Public License version 1.1
-
-
-
-
- Mozilla Public License Version 1.1
- 1. Definitions.
-
- 1.0.1. "Commercial Use"
- means distribution or otherwise making the Covered Code available to a third party.
- 1.1. "Contributor"
- means each entity that creates or contributes to the creation of Modifications.
- 1.2. "Contributor Version"
- means the combination of the Original Code, prior Modifications used by a Contributor,
- and the Modifications made by that particular Contributor.
- 1.3. "Covered Code"
- means the Original Code or Modifications or the combination of the Original Code and
- Modifications, in each case including portions thereof.
- 1.4. "Electronic Distribution Mechanism"
- means a mechanism generally accepted in the software development community for the
- electronic transfer of data.
- 1.5. "Executable"
- means Covered Code in any form other than Source Code.
- 1.6. "Initial Developer"
- means the individual or entity identified as the Initial Developer in the Source Code
- notice required by Exhibit A .
- 1.7. "Larger Work"
- means a work which combines Covered Code or portions thereof with code not governed
- by the terms of this License.
- 1.8. "License"
- means this document.
- 1.8.1. "Licensable"
- means having the right to grant, to the maximum extent possible, whether at the
- time of the initial grant or subsequently acquired, any and all of the rights
- conveyed herein.
- 1.9. "Modifications"
-
- means any addition to or deletion from the substance or structure of either the
- Original Code or any previous Modifications. When Covered Code is released as a
- series of files, a Modification is:
-
- Any addition to or deletion from the contents of a file
- containing Original Code or previous Modifications.
- Any new file that contains any part of the Original Code or
- previous Modifications.
-
- 1.10. "Original Code"
- means Source Code of computer software code which is described in the Source Code
- notice required by Exhibit A as Original Code, and which,
- at the time of its release under this License is not already Covered Code governed
- by this License.
- 1.10.1. "Patent Claims"
- means any patent claim(s), now owned or hereafter acquired, including without
- limitation, method, process, and apparatus claims, in any patent Licensable by
- grantor.
- 1.11. "Source Code"
- means the preferred form of the Covered Code for making modifications to it,
- including all modules it contains, plus any associated interface definition files,
- scripts used to control compilation and installation of an Executable, or source
- code differential comparisons against either the Original Code or another well known,
- available Covered Code of the Contributor's choice. The Source Code can be in a
- compressed or archival form, provided the appropriate decompression or de-archiving
- software is widely available for no charge.
- 1.12. "You" (or "Your")
- means an individual or a legal entity exercising rights under, and complying with
- all of the terms of, this License or a future version of this License issued under
- Section 6.1. For legal entities, "You" includes any entity
- which controls, is controlled by, or is under common control with You. For purposes of
- this definition, "control" means (a) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or otherwise, or (b)
- ownership of more than fifty percent (50%) of the outstanding shares or beneficial
- ownership of such entity.
-
- 2. Source Code License.
- 2.1. The Initial Developer Grant.
- The Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive
- license, subject to third party intellectual property claims:
-
- under intellectual property rights (other than patent or
- trademark) Licensable by Initial Developer to use, reproduce, modify, display, perform,
- sublicense and distribute the Original Code (or portions thereof) with or without
- Modifications, and/or as part of a Larger Work; and
- under Patents Claims infringed by the making, using or selling
- of Original Code, to make, have made, use, practice, sell, and offer for sale, and/or
- otherwise dispose of the Original Code (or portions thereof).
- the licenses granted in this Section 2.1
- (a ) and (b ) are effective on
- the date Initial Developer first distributes Original Code under the terms of this
- License.
- Notwithstanding Section 2.1 (b )
- above, no patent license is granted: 1) for code that You delete from the Original Code;
- 2) separate from the Original Code; or 3) for infringements caused by: i) the
- modification of the Original Code or ii) the combination of the Original Code with other
- software or devices.
-
- 2.2. Contributor Grant.
- Subject to third party intellectual property claims, each Contributor hereby grants You
- a world-wide, royalty-free, non-exclusive license
-
- under intellectual property rights (other than patent or trademark)
- Licensable by Contributor, to use, reproduce, modify, display, perform, sublicense and
- distribute the Modifications created by such Contributor (or portions thereof) either on
- an unmodified basis, with other Modifications, as Covered Code and/or as part of a Larger
- Work; and
- under Patent Claims infringed by the making, using, or selling of
- Modifications made by that Contributor either alone and/or in combination with its
- Contributor Version (or portions of such combination), to make, use, sell, offer for
- sale, have made, and/or otherwise dispose of: 1) Modifications made by that Contributor
- (or portions thereof); and 2) the combination of Modifications made by that Contributor
- with its Contributor Version (or portions of such combination).
- the licenses granted in Sections 2.2
- (a ) and 2.2 (b ) are effective
- on the date Contributor first makes Commercial Use of the Covered Code.
- Notwithstanding Section 2.2 (b )
- above, no patent license is granted: 1) for any code that Contributor has deleted from
- the Contributor Version; 2) separate from the Contributor Version; 3) for infringements
- caused by: i) third party modifications of Contributor Version or ii) the combination of
- Modifications made by that Contributor with other software (except as part of the
- Contributor Version) or other devices; or 4) under Patent Claims infringed by Covered Code
- in the absence of Modifications made by that Contributor.
-
- 3. Distribution Obligations.
- 3.1. Application of License.
- The Modifications which You create or to which You contribute are governed by the terms
- of this License, including without limitation Section 2.2 . The
- Source Code version of Covered Code may be distributed only under the terms of this License
- or a future version of this License released under Section 6.1 ,
- and You must include a copy of this License with every copy of the Source Code You
- distribute. You may not offer or impose any terms on any Source Code version that alters or
- restricts the applicable version of this License or the recipients' rights hereunder.
- However, You may include an additional document offering the additional rights described in
- Section 3.5 .
-
- 3.2. Availability of Source Code.
- Any Modification which You create or to which You contribute must be made available in
- Source Code form under the terms of this License either on the same media as an Executable
- version or via an accepted Electronic Distribution Mechanism to anyone to whom you made an
- Executable version available; and if made available via Electronic Distribution Mechanism,
- must remain available for at least twelve (12) months after the date it initially became
- available, or at least six (6) months after a subsequent version of that particular
- Modification has been made available to such recipients. You are responsible for ensuring
- that the Source Code version remains available even if the Electronic Distribution
- Mechanism is maintained by a third party.
-
- 3.3. Description of Modifications.
- You must cause all Covered Code to which You contribute to contain a file documenting the
- changes You made to create that Covered Code and the date of any change. You must include a
- prominent statement that the Modification is derived, directly or indirectly, from Original
- Code provided by the Initial Developer and including the name of the Initial Developer in
- (a) the Source Code, and (b) in any notice in an Executable version or related documentation
- in which You describe the origin or ownership of the Covered Code.
-
- 3.4. Intellectual Property Matters
- (a) Third Party Claims
- If Contributor has knowledge that a license under a third party's intellectual property
- rights is required to exercise the rights granted by such Contributor under Sections
- 2.1 or 2.2 , Contributor must include a
- text file with the Source Code distribution titled "LEGAL" which describes the claim and the
- party making the claim in sufficient detail that a recipient will know whom to contact. If
- Contributor obtains such knowledge after the Modification is made available as described in
- Section 3.2 , Contributor shall promptly modify the LEGAL file in
- all copies Contributor makes available thereafter and shall take other steps (such as
- notifying appropriate mailing lists or newsgroups) reasonably calculated to inform those who
- received the Covered Code that new knowledge has been obtained.
-
- (b) Contributor APIs
- If Contributor's Modifications include an application programming interface and Contributor
- has knowledge of patent licenses which are reasonably necessary to implement that
- API , Contributor must also include this information in the
- legal file.
-
- (c) Representations.
- Contributor represents that, except as disclosed pursuant to Section 3.4
- (a ) above, Contributor believes that Contributor's Modifications
- are Contributor's original creation(s) and/or Contributor has sufficient rights to grant the
- rights conveyed by this License.
-
- 3.5. Required Notices.
- You must duplicate the notice in Exhibit A in each file of the
- Source Code. If it is not possible to put such notice in a particular Source Code file due to
- its structure, then You must include such notice in a location (such as a relevant directory)
- where a user would be likely to look for such a notice. If You created one or more
- Modification(s) You may add your name as a Contributor to the notice described in
- Exhibit A . You must also duplicate this License in any documentation
- for the Source Code where You describe recipients' rights or ownership rights relating to
- Covered Code. You may choose to offer, and to charge a fee for, warranty, support, indemnity
- or liability obligations to one or more recipients of Covered Code. However, You may do so
- only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You
- must make it absolutely clear than any such warranty, support, indemnity or liability
- obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer
- and every Contributor for any liability incurred by the Initial Developer or such Contributor
- as a result of warranty, support, indemnity or liability terms You offer.
-
- 3.6. Distribution of Executable Versions.
- You may distribute Covered Code in Executable form only if the requirements of Sections
- 3.1 , 3.2 ,
- 3.3 , 3.4 and
- 3.5 have been met for that Covered Code, and if You include a
- notice stating that the Source Code version of the Covered Code is available under the terms
- of this License, including a description of how and where You have fulfilled the obligations
- of Section 3.2 . The notice must be conspicuously included in any
- notice in an Executable version, related documentation or collateral in which You describe
- recipients' rights relating to the Covered Code. You may distribute the Executable version of
- Covered Code or ownership rights under a license of Your choice, which may contain terms
- different from this License, provided that You are in compliance with the terms of this
- License and that the license for the Executable version does not attempt to limit or alter the
- recipient's rights in the Source Code version from the rights set forth in this License. If
- You distribute the Executable version under a different license You must make it absolutely
- clear that any terms which differ from this License are offered by You alone, not by the
- Initial Developer or any Contributor. You hereby agree to indemnify the Initial Developer and
- every Contributor for any liability incurred by the Initial Developer or such Contributor as
- a result of any such terms You offer.
-
- 3.7. Larger Works.
- You may create a Larger Work by combining Covered Code with other code not governed by the
- terms of this License and distribute the Larger Work as a single product. In such a case,
- You must make sure the requirements of this License are fulfilled for the Covered Code.
-
- 4. Inability to Comply Due to Statute or Regulation.
- If it is impossible for You to comply with any of the terms of this License with respect to
- some or all of the Covered Code due to statute, judicial order, or regulation then You must:
- (a) comply with the terms of this License to the maximum extent possible; and (b) describe
- the limitations and the code they affect. Such description must be included in the
- legal file described in Section
- 3.4 and must be included with all distributions of the Source Code.
- Except to the extent prohibited by statute or regulation, such description must be
- sufficiently detailed for a recipient of ordinary skill to be able to understand it.
-
- 5. Application of this License.
- This License applies to code to which the Initial Developer has attached the notice in
- Exhibit A and to related Covered Code.
-
- 6. Versions of the License.
- 6.1. New Versions
- Netscape Communications Corporation ("Netscape") may publish revised and/or new versions
- of the License from time to time. Each version will be given a distinguishing version number.
-
- 6.2. Effect of New Versions
- Once Covered Code has been published under a particular version of the License, You may
- always continue to use it under the terms of that version. You may also choose to use such
- Covered Code under the terms of any subsequent version of the License published by Netscape.
- No one other than Netscape has the right to modify the terms applicable to Covered Code
- created under this License.
-
- 6.3. Derivative Works
- If You create or use a modified version of this License (which you may only do in order to
- apply it to code which is not already Covered Code governed by this License), You must (a)
- rename Your license so that the phrases "Mozilla", "MOZILLAPL", "MOZPL", "Netscape", "MPL",
- "NPL" or any confusingly similar phrase do not appear in your license (except to note that
- your license differs from this License) and (b) otherwise make it clear that Your version of
- the license contains terms which differ from the Mozilla Public License and Netscape Public
- License. (Filling in the name of the Initial Developer, Original Code or Contributor in the
- notice described in Exhibit A shall not of themselves be deemed to
- be modifications of this License.)
-
- 7. Disclaimer of warranty
- Covered code is provided under this license on an "as is"
- basis, without warranty of any kind, either expressed or implied, including, without
- limitation, warranties that the covered code is free of defects, merchantable, fit for a
- particular purpose or non-infringing. The entire risk as to the quality and performance of
- the covered code is with you. Should any covered code prove defective in any respect, you
- (not the initial developer or any other contributor) assume the cost of any necessary
- servicing, repair or correction. This disclaimer of warranty constitutes an essential part
- of this license. No use of any covered code is authorized hereunder except under this
- disclaimer.
-
- 8. Termination
- 8.1. This License and the rights granted hereunder will terminate
- automatically if You fail to comply with terms herein and fail to cure such breach
- within 30 days of becoming aware of the breach. All sublicenses to the Covered Code which
- are properly granted shall survive any termination of this License. Provisions which, by
- their nature, must remain in effect beyond the termination of this License shall survive.
-
- 8.2. If You initiate litigation by asserting a patent infringement
- claim (excluding declatory judgment actions) against Initial Developer or a Contributor
- (the Initial Developer or Contributor against whom You file such action is referred to
- as "Participant") alleging that:
-
- such Participant's Contributor Version directly or indirectly
- infringes any patent, then any and all rights granted by such Participant to You under
- Sections 2.1 and/or 2.2 of this
- License shall, upon 60 days notice from Participant terminate prospectively, unless if
- within 60 days after receipt of notice You either: (i) agree in writing to pay
- Participant a mutually agreeable reasonable royalty for Your past and future use of
- Modifications made by such Participant, or (ii) withdraw Your litigation claim with
- respect to the Contributor Version against such Participant. If within 60 days of
- notice, a reasonable royalty and payment arrangement are not mutually agreed upon in
- writing by the parties or the litigation claim is not withdrawn, the rights granted by
- Participant to You under Sections 2.1 and/or
- 2.2 automatically terminate at the expiration of the 60 day
- notice period specified above.
- any software, hardware, or device, other than such Participant's
- Contributor Version, directly or indirectly infringes any patent, then any rights
- granted to You by such Participant under Sections 2.1(b )
- and 2.2(b ) are revoked effective as of the date You first
- made, used, sold, distributed, or had made, Modifications made by that Participant.
-
- 8.3. If You assert a patent infringement claim against Participant
- alleging that such Participant's Contributor Version directly or indirectly infringes
- any patent where such claim is resolved (such as by license or settlement) prior to the
- initiation of patent infringement litigation, then the reasonable value of the licenses
- granted by such Participant under Sections 2.1 or
- 2.2 shall be taken into account in determining the amount or
- value of any payment or license.
-
- 8.4. In the event of termination under Sections
- 8.1 or 8.2 above, all end user
- license agreements (excluding distributors and resellers) which have been validly
- granted by You or any distributor hereunder prior to termination shall survive
- termination.
-
- 9. Limitation of liability
- Under no circumstances and under no legal theory, whether
- tort (including negligence), contract, or otherwise, shall you, the initial developer,
- any other contributor, or any distributor of covered code, or any supplier of any of
- such parties, be liable to any person for any indirect, special, incidental, or
- consequential damages of any character including, without limitation, damages for loss
- of goodwill, work stoppage, computer failure or malfunction, or any and all other
- commercial damages or losses, even if such party shall have been informed of the
- possibility of such damages. This limitation of liability shall not apply to liability
- for death or personal injury resulting from such party's negligence to the extent
- applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion
- or limitation of incidental or consequential damages, so this exclusion and limitation
- may not apply to you.
-
- 10. U.S. government end users
- The Covered Code is a "commercial item," as that term is defined in 48
- C.F.R. 2.101 (Oct. 1995), consisting of
- "commercial computer software" and "commercial computer software documentation," as such
- terms are used in 48 C.F.R. 12.212 (Sept.
- 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R.
- 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users
- acquire Covered Code with only those rights set forth herein.
-
- 11. Miscellaneous
- This License represents the complete agreement concerning subject matter hereof. If
- any provision of this License is held to be unenforceable, such provision shall be
- reformed only to the extent necessary to make it enforceable. This License shall be
- governed by California law provisions (except to the extent applicable law, if any,
- provides otherwise), excluding its conflict-of-law provisions. With respect to
- disputes in which at least one party is a citizen of, or an entity chartered or
- registered to do business in the United States of America, any litigation relating to
- this License shall be subject to the jurisdiction of the Federal Courts of the
- Northern District of California, with venue lying in Santa Clara County, California,
- with the losing party responsible for costs, including without limitation, court
- costs and reasonable attorneys' fees and expenses. The application of the United
- Nations Convention on Contracts for the International Sale of Goods is expressly
- excluded. Any law or regulation which provides that the language of a contract
- shall be construed against the drafter shall not apply to this License.
-
- 12. Responsibility for claims
- As between Initial Developer and the Contributors, each party is responsible for
- claims and damages arising, directly or indirectly, out of its utilization of rights
- under this License and You agree to work with Initial Developer and Contributors to
- distribute such responsibility on an equitable basis. Nothing herein is intended or
- shall be deemed to constitute any admission of liability.
-
- 13. Multiple-licensed code
- Initial Developer may designate portions of the Covered Code as
- "Multiple-Licensed". "Multiple-Licensed" means that the Initial Developer permits
- you to utilize portions of the Covered Code under Your choice of the MPL
- or the alternative licenses, if any, specified by the Initial Developer in the file
- described in Exhibit A .
-
- Exhibit A - Mozilla Public License.
- "The contents of this file are subject to the Mozilla Public License
-Version 1.1 (the "License"); you may not use this file except in
-compliance with the License. You may obtain a copy of the License at
-http://www.mozilla.org/MPL/
-
-Software distributed under the License is distributed on an "AS IS"
-basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
-License for the specific language governing rights and limitations
-under the License.
-
-The Original Code is JTransforms.
-
-The Initial Developer of the Original Code is
-Piotr Wendykier, Emory University.
-Portions created by the Initial Developer are Copyright (C) 2007-2009
-the Initial Developer. All Rights Reserved.
-
-Alternatively, the contents of this file may be used under the terms of
-either the GNU General Public License Version 2 or later (the "GPL"), or
-the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
-in which case the provisions of the GPL or the LGPL are applicable instead
-of those above. If you wish to allow use of your version of this file only
-under the terms of either the GPL or the LGPL, and not to allow others to
-use your version of this file under the terms of the MPL, indicate your
-decision by deleting the provisions above and replace them with the notice
-and other provisions required by the GPL or the LGPL. If you do not delete
-the provisions above, a recipient may use your version of this file under
-the terms of any one of the MPL, the GPL or the LGPL.
- NOTE: The text of this Exhibit A may differ slightly from the text of
- the notices in the Source Code files of the Original Code. You should
- use the text of this Exhibit A rather than the text found in the
- Original Code Source Code for Your Modifications.
-
-
\ No newline at end of file
diff --git a/licenses-binary/LICENSE-re2j.txt b/licenses-binary/LICENSE-re2j.txt
new file mode 100644
index 0000000000000..0dc3cd70bf1f7
--- /dev/null
+++ b/licenses-binary/LICENSE-re2j.txt
@@ -0,0 +1,32 @@
+This is a work derived from Russ Cox's RE2 in Go, whose license
+http://golang.org/LICENSE is as follows:
+
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+
+ * Neither the name of Google Inc. nor the names of its contributors
+ may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
index 2a0f8c11d0a50..e054a15fc9b75 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
@@ -302,7 +302,7 @@ private[spark] object BLAS extends Serializable {
* @param x the vector x that contains the n elements.
* @param A the symmetric matrix A. Size of n x n.
*/
- def syr(alpha: Double, x: Vector, A: DenseMatrix) {
+ def syr(alpha: Double, x: Vector, A: DenseMatrix): Unit = {
val mA = A.numRows
val nA = A.numCols
require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA")
@@ -316,7 +316,7 @@ private[spark] object BLAS extends Serializable {
}
}
- private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+ private def syr(alpha: Double, x: DenseVector, A: DenseMatrix): Unit = {
val nA = A.numRows
val mA = A.numCols
@@ -334,7 +334,7 @@ private[spark] object BLAS extends Serializable {
}
}
- private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
+ private def syr(alpha: Double, x: SparseVector, A: DenseMatrix): Unit = {
val mA = A.numCols
val xIndices = x.indices
val xValues = x.values
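Editor's note: the three `syr` hunks above are purely syntactic; they replace Scala's deprecated procedure syntax with an explicit `Unit` result type, leaving the behaviour unchanged. A minimal before/after sketch, using a hypothetical helper:

```scala
// Before (procedure syntax, deprecated in newer Scala versions):
//   def logTotal(values: Array[Double]) { println(values.sum) }

// After (explicit Unit result type, identical behaviour):
def logTotal(values: Array[Double]): Unit = {
  println(values.sum)
}
```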
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index 6e43d60bd03a3..f437d66cddb54 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -178,6 +178,14 @@ sealed trait Vector extends Serializable {
*/
@Since("2.0.0")
def argmax: Int
+
+ /**
+ * Calculate the dot product of this vector with another.
+ *
+ * If `size` does not match, an [[IllegalArgumentException]] is thrown.
+ */
+ @Since("3.0.0")
+ def dot(v: Vector): Double = BLAS.dot(this, v)
}
/**
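Editor's note: a minimal sketch of how the new `Vector.dot` API added above could be used; the vector values are illustrative only.

```scala
import org.apache.spark.ml.linalg.Vectors

val dv = Vectors.dense(0.1, 0.0, 0.3)
val sv = Vectors.sparse(3, Array(0, 2), Array(0.1, 0.3))

dv.dot(sv)   // 0.1 * 0.1 + 0.3 * 0.3 = 0.10; dense/sparse combinations all route through BLAS.dot
sv.dot(sv)   // same value

// Mismatched sizes throw IllegalArgumentException, as documented above:
// Vectors.zeros(1).dot(dv)
```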
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala
index 332734bd28341..7d29d6dcea908 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala
@@ -21,7 +21,7 @@ import java.util.Random
import breeze.linalg.{CSCMatrix, Matrix => BM}
import org.mockito.Mockito.when
-import org.scalatest.mockito.MockitoSugar._
+import org.scalatestplus.mockito.MockitoSugar._
import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.ml.SparkMLFunSuite
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
index 0a316f57f811b..c97dc2c3c06f8 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
@@ -380,4 +380,27 @@ class VectorsSuite extends SparkMLFunSuite {
Vectors.sparse(-1, Array((1, 2.0)))
}
}
+
+ test("dot product only supports vectors of same size") {
+ val vSize4 = Vectors.dense(arr)
+ val vSize1 = Vectors.zeros(1)
+ intercept[IllegalArgumentException]{ vSize1.dot(vSize4) }
+ }
+
+ test("dense vector dot product") {
+ val dv = Vectors.dense(arr)
+ assert(dv.dot(dv) === 0.26)
+ }
+
+ test("sparse vector dot product") {
+ val sv = Vectors.sparse(n, indices, values)
+ assert(sv.dot(sv) === 0.26)
+ }
+
+ test("mixed sparse and dense vector dot product") {
+ val sv = Vectors.sparse(n, indices, values)
+ val dv = Vectors.dense(arr)
+ assert(sv.dot(dv) === 0.26)
+ assert(dv.dot(sv) === 0.26)
+ }
}
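Editor's note on the expected values in the new tests: assuming the suite's shared fixture is the dense vector (0.1, 0.0, 0.3, 0.4) with the matching sparse encoding, the expected dot product is 0.1^2 + 0.3^2 + 0.4^2 = 0.01 + 0.09 + 0.16 = 0.26, which is the constant asserted in all four tests.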
diff --git a/mllib/benchmarks/UDTSerializationBenchmark-jdk11-results.txt b/mllib/benchmarks/UDTSerializationBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..6f671405b8343
--- /dev/null
+++ b/mllib/benchmarks/UDTSerializationBenchmark-jdk11-results.txt
@@ -0,0 +1,12 @@
+================================================================================================
+VectorUDT de/serialization
+================================================================================================
+
+OpenJDK 64-Bit Server VM 11.0.4+11-LTS on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+VectorUDT de/serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+serialize 269 292 13 0.0 269441.1 1.0X
+deserialize 164 191 9 0.0 164314.6 1.6X
+
+
diff --git a/mllib/benchmarks/UDTSerializationBenchmark-results.txt b/mllib/benchmarks/UDTSerializationBenchmark-results.txt
index 169f4c60c748e..a0c853e99014b 100644
--- a/mllib/benchmarks/UDTSerializationBenchmark-results.txt
+++ b/mllib/benchmarks/UDTSerializationBenchmark-results.txt
@@ -2,12 +2,11 @@
VectorUDT de/serialization
================================================================================================
-Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.13.6
-Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
-
-VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-serialize 144 / 206 0.0 143979.7 1.0X
-deserialize 114 / 135 0.0 113802.6 1.3X
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+VectorUDT de/serialization: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+serialize 271 294 12 0.0 271054.3 1.0X
+deserialize 190 192 2 0.0 189706.1 1.4X
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 58815434cbdaf..9eac8ed22a3f6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml
import org.apache.spark.annotation.{DeveloperApi, Since}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -62,6 +62,39 @@ private[ml] trait PredictorParams extends Params
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
+
+ /**
+ * Extract [[labelCol]], weightCol (if any) and [[featuresCol]] from the given dataset,
+ * and put them in an RDD with strong types.
+ */
+ protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
+ val w = this match {
+ case p: HasWeightCol =>
+ if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+ col($(p.weightCol)).cast(DoubleType)
+ } else {
+ lit(1.0)
+ }
+ }
+
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
+ }
+ }
+
+ /**
+ * Extract [[labelCol]], weightCol (if any) and [[featuresCol]] from the given dataset,
+ * and put them in an RDD with strong types.
+ * Validate the output instances with the given function.
+ */
+ protected def extractInstances(dataset: Dataset[_],
+ validateInstance: Instance => Unit): RDD[Instance] = {
+ extractInstances(dataset).map { instance =>
+ validateInstance(instance)
+ instance
+ }
+ }
}
/**
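Editor's note: a standalone sketch of what the new `extractInstances` helper does. The helper itself is `protected` and `Instance` is Spark-internal, so this only mirrors the behaviour rather than exercising public API; the `label`/`features` column names are assumptions.

```scala
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

// Every row becomes a strongly typed Instance(label, weight, features);
// the weight defaults to 1.0 when no weight column is configured.
def toInstances(dataset: Dataset[_], weightCol: Option[String]): RDD[Instance] = {
  val w = weightCol.filter(_.nonEmpty).map(c => col(c).cast(DoubleType)).getOrElse(lit(1.0))
  dataset.select(col("label").cast(DoubleType), w, col("features")).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      Instance(label, weight, features)
  }
}
```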
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index b6b02e77909bd..9ac673078d4ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils}
@@ -42,6 +42,22 @@ private[spark] trait ClassifierParams
val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
}
+
+ /**
+ * Extract [[labelCol]], weightCol (if any) and [[featuresCol]] from the given dataset,
+ * and put them in an RDD with strong types.
+ * Validates that each label is an integer in the range [0, numClasses).
+ */
+ protected def extractInstances(dataset: Dataset[_],
+ numClasses: Int): RDD[Instance] = {
+ val validateInstance = (instance: Instance) => {
+ val label = instance.label
+ require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given a" +
+ s" dataset with invalid label $label. Labels must be integers in range" +
+ s" [0, $numClasses).")
+ }
+ extractInstances(dataset, validateInstance)
+ }
}
/**
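Editor's note: the classifier overload above only adds a per-row label check on top of `extractInstances(dataset)`. A small sketch of that check with hypothetical values:

```scala
def isValidLabel(label: Double, numClasses: Int): Boolean =
  label.toLong == label && label >= 0 && label < numClasses

isValidLabel(2.0, numClasses = 3)   // true
isValidLabel(2.5, numClasses = 3)   // false: not an integer
isValidLabel(3.0, numClasses = 3)   // false: outside [0, numClasses)
```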
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 6bd8a26f5b1a8..2d0212f36fad4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -22,7 +22,7 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit, udf}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -116,9 +115,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
instr.logPipelineStage(this)
instr.logDataset(dataset)
- val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val numClasses: Int = getNumClasses(dataset)
+ val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val numClasses = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -126,13 +124,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
validateNumClasses(numClasses)
- val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances =
- dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, features: Vector) =>
- validateLabel(label, numClasses)
- Instance(label, weight, features)
- }
+ val instances = extractInstances(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
instr.logNumClasses(numClasses)
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 78503585261bf..e467228b4cc14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -36,9 +36,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
/** Params for linear SVM Classifier. */
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
@@ -161,12 +159,7 @@ class LinearSVC @Since("2.2.0") (
override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr =>
- val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances: RDD[Instance] =
- dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
- }
+ val instances = extractInstances(dataset)
instr.logPipelineStage(this)
instr.logDataset(dataset)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 0997c1e7b38d6..af6e2b39ecb60 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -40,9 +40,8 @@ import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, Multiclas
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils
@@ -492,12 +491,7 @@ class LogisticRegression @Since("1.2.0") (
protected[spark] def train(
dataset: Dataset[_],
handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
- val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances: RDD[Instance] =
- dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
- }
+ val instances = extractInstances(dataset)
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 47b8a8df637b9..41db6f3f44342 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
import org.apache.spark.ml.feature.OneHotEncoderModel
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index e97af0582d358..205f565aa2685 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasWeightCol
@@ -28,7 +29,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.functions.col
/**
* Params for Naive Bayes Classifiers.
@@ -137,17 +138,14 @@ class NaiveBayes @Since("1.5.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
- val modelTypeValue = $(modelType)
- val requireValues: Vector => Unit = {
- modelTypeValue match {
- case Multinomial =>
- requireNonnegativeValues
- case Bernoulli =>
- requireZeroOneBernoulliValues
- case _ =>
- // This should never happen.
- throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
- }
+ val validateInstance = $(modelType) match {
+ case Multinomial =>
+ (instance: Instance) => requireNonnegativeValues(instance.features)
+ case Bernoulli =>
+ (instance: Instance) => requireZeroOneBernoulliValues(instance.features)
+ case _ =>
+ // This should never happen.
+ throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
@@ -155,17 +153,15 @@ class NaiveBayes @Since("1.5.0") (
val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
instr.logNumFeatures(numFeatures)
- val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
// Aggregates term frequencies per label.
// TODO: Calling aggregateByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
- val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
- .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
- }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
+ val aggregated = extractInstances(dataset, validateInstance).map { instance =>
+ (instance.label, (instance.weight, instance.features))
+ }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
seqOp = {
case ((weightSum, featureSum, count), (weight, features)) =>
- requireValues(features)
BLAS.axpy(weight, features, featureSum)
(weightSum + weight, featureSum, count + 1)
},
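Editor's note: a sketch of the per-instance feature checks selected by `modelType` above. The real helpers (`requireNonnegativeValues`, `requireZeroOneBernoulliValues`) are private to `NaiveBayes`, so these only mirror their intent.

```scala
import org.apache.spark.ml.linalg.Vector

def requireNonnegative(v: Vector): Unit =
  v.foreachActive { (_, value) =>
    require(value >= 0.0,
      s"Multinomial naive Bayes requires nonnegative feature values but found $v.")
  }

def requireZeroOne(v: Vector): Unit =
  v.foreachActive { (_, value) =>
    require(value == 0.0 || value == 1.0,
      s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
  }
```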
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 86caa1247e77f..979eb5e5448a8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -33,8 +33,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix,
Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -111,28 +111,32 @@ class GaussianMixtureModel private[ml] (
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- var predictionColNames = Seq.empty[String]
- var predictionColumns = Seq.empty[Column]
-
- if ($(predictionCol).nonEmpty) {
- val predUDF = udf((vector: Vector) => predict(vector))
- predictionColNames :+= $(predictionCol)
- predictionColumns :+= predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))
- }
+ val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
+ var outputData = dataset
+ var numColsOutput = 0
if ($(probabilityCol).nonEmpty) {
val probUDF = udf((vector: Vector) => predictProbability(vector))
- predictionColNames :+= $(probabilityCol)
- predictionColumns :+= probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))
+ outputData = outputData.withColumn($(probabilityCol), probUDF(vectorCol))
+ numColsOutput += 1
+ }
+
+ if ($(predictionCol).nonEmpty) {
+ if ($(probabilityCol).nonEmpty) {
+ val predUDF = udf((vector: Vector) => vector.argmax)
+ outputData = outputData.withColumn($(predictionCol), predUDF(col($(probabilityCol))))
+ } else {
+ val predUDF = udf((vector: Vector) => predict(vector))
+ outputData = outputData.withColumn($(predictionCol), predUDF(vectorCol))
+ }
+ numColsOutput += 1
}
- if (predictionColNames.nonEmpty) {
- dataset.withColumns(predictionColNames, predictionColumns)
- } else {
+ if (numColsOutput == 0) {
this.logWarning(s"$uid: GaussianMixtureModel.transform() does nothing" +
" because no output columns were set.")
- dataset.toDF()
}
+ outputData.toDF
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 2a7b3c579b078..09e8e7b232f3a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -59,6 +59,28 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
@Since("1.2.0")
def setMetricName(value: String): this.type = set(metricName, value)
+ /**
+ * param for number of bins to down-sample the curves (ROC curve, PR curve) in area
+ * computation. If 0, no down-sampling will occur.
+ * Default: 1000.
+ * @group expertParam
+ */
+ @Since("3.0.0")
+ val numBins: IntParam = new IntParam(this, "numBins", "Number of bins to down-sample " +
+ "the curves (ROC curve, PR curve) in area computation. If 0, no down-sampling will occur. " +
+ "Must be >= 0.",
+ ParamValidators.gtEq(0))
+
+ /** @group expertGetParam */
+ @Since("3.0.0")
+ def getNumBins: Int = $(numBins)
+
+ /** @group expertSetParam */
+ @Since("3.0.0")
+ def setNumBins(value: Int): this.type = set(numBins, value)
+
+ setDefault(numBins -> 1000)
+
/** @group setParam */
@Since("1.5.0")
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
@@ -94,7 +116,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
case Row(rawPrediction: Double, label: Double, weight: Double) =>
(rawPrediction, label, weight)
}
- val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights)
+ val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, $(numBins))
val metric = $(metricName) match {
case "areaUnderROC" => metrics.areaUnderROC()
case "areaUnderPR" => metrics.areaUnderPR()
@@ -104,10 +126,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
}
@Since("1.5.0")
- override def isLargerBetter: Boolean = $(metricName) match {
- case "areaUnderROC" => true
- case "areaUnderPR" => true
- }
+ override def isLargerBetter: Boolean = true
@Since("1.4.1")
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
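Editor's note: a hedged example of the new `numBins` expert param; `predictions` is assumed to be a DataFrame containing the evaluator's rawPrediction/label (and optional weight) columns.

```scala
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val evaluator = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderPR")
  .setNumBins(0)   // 0 disables curve down-sampling; the default is 1000

val areaUnderPR = evaluator.evaluate(predictions)
```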
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index dd667a85fa598..b0cafefe420a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
@@ -43,13 +43,14 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
* - `"mse"`: mean squared error
* - `"r2"`: R^2^ metric
* - `"mae"`: mean absolute error
+ * - `"var"`: explained variance
*
* @group param
*/
@Since("1.4.0")
val metricName: Param[String] = {
- val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
- new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams)
+ val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae", "var"))
+ new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae|var)", allowedParams)
}
/** @group getParam */
@@ -60,6 +61,25 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
def setMetricName(value: String): this.type = set(metricName, value)
+ /**
+ * param for whether the regression is through the origin.
+ * Default: false.
+ * @group expertParam
+ */
+ @Since("3.0.0")
+ val throughOrigin: BooleanParam = new BooleanParam(this, "throughOrigin",
+ "Whether the regression is through the origin.")
+
+ /** @group expertGetParam */
+ @Since("3.0.0")
+ def getThroughOrigin: Boolean = $(throughOrigin)
+
+ /** @group expertSetParam */
+ @Since("3.0.0")
+ def setThroughOrigin(value: Boolean): this.type = set(throughOrigin, value)
+
+ setDefault(throughOrigin -> false)
+
/** @group setParam */
@Since("1.4.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
@@ -86,22 +106,20 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
.rdd
.map { case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight) }
- val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
- val metric = $(metricName) match {
+ val metrics = new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin))
+ $(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
case "mse" => metrics.meanSquaredError
case "r2" => metrics.r2
case "mae" => metrics.meanAbsoluteError
+ case "var" => metrics.explainedVariance
}
- metric
}
@Since("1.4.0")
override def isLargerBetter: Boolean = $(metricName) match {
- case "rmse" => false
- case "mse" => false
- case "r2" => true
- case "mae" => false
+ case "r2" | "var" => true
+ case _ => false
}
@Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 2b0862c60fdf7..c4daf64dfc5f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -75,30 +75,40 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
val schema = dataset.schema
val inputType = schema($(inputCol)).dataType
val td = $(threshold)
+ val metadata = outputSchema($(outputCol)).metadata
- val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
- val binarizerVector = udf { (data: Vector) =>
- val indices = ArrayBuilder.make[Int]
- val values = ArrayBuilder.make[Double]
-
- data.foreachActive { (index, value) =>
- if (value > td) {
- indices += index
- values += 1.0
+ val binarizerUDF = inputType match {
+ case DoubleType =>
+ udf { in: Double => if (in > td) 1.0 else 0.0 }
+
+ case _: VectorUDT if td >= 0 =>
+ udf { vector: Vector =>
+ val indices = ArrayBuilder.make[Int]
+ val values = ArrayBuilder.make[Double]
+ vector.foreachActive { (index, value) =>
+ if (value > td) {
+ indices += index
+ values += 1.0
+ }
+ }
+ Vectors.sparse(vector.size, indices.result(), values.result()).compressed
}
- }
- Vectors.sparse(data.size, indices.result(), values.result()).compressed
+ case _: VectorUDT if td < 0 =>
+ this.logWarning(s"Binarization operations on a sparse dataset with negative threshold " +
+ s"$td will build a dense output, so take care when applying it to sparse input.")
+ udf { vector: Vector =>
+ val values = Array.fill(vector.size)(1.0)
+ vector.foreachActive { (index, value) =>
+ if (value <= td) {
+ values(index) = 0.0
+ }
+ }
+ Vectors.dense(values).compressed
+ }
}
- val metadata = outputSchema($(outputCol)).metadata
-
- inputType match {
- case DoubleType =>
- dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
- case _: VectorUDT =>
- dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata))
- }
+ dataset.withColumn($(outputCol), binarizerUDF(col($(inputCol))), metadata)
}
@Since("1.4.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
index 32d98151bdcff..84d6a536ccca8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import edu.emory.mathcs.jtransforms.dct._
+import org.jtransforms.dct._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 5bfaa3b7f3f52..f7a83cdd41a90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -167,25 +167,38 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
- private[feature] def getInOutCols: (Array[String], Array[String]) = {
- require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
- (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
- "QuantileDiscretizer only supports setting either inputCol/outputCol or" +
- "inputCols/outputCols."
- )
+ @Since("1.6.0")
+ override def transformSchema(schema: StructType): StructType = {
+ ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol),
+ Seq(outputCols))
if (isSet(inputCol)) {
- (Array($(inputCol)), Array($(outputCol)))
- } else {
- require($(inputCols).length == $(outputCols).length,
- "inputCols number do not match outputCols")
- ($(inputCols), $(outputCols))
+ require(!isSet(numBucketsArray),
+ s"numBucketsArray can't be set for single-column QuantileDiscretizer.")
}
- }
- @Since("1.6.0")
- override def transformSchema(schema: StructType): StructType = {
- val (inputColNames, outputColNames) = getInOutCols
+ if (isSet(inputCols)) {
+ require(getInputCols.length == getOutputCols.length,
+ s"QuantileDiscretizer $this has mismatched Params " +
+ s"for multi-column transform. Params (inputCols, outputCols) should have " +
+ s"equal lengths, but they have different lengths: " +
+ s"(${getInputCols.length}, ${getOutputCols.length}).")
+ if (isSet(numBucketsArray)) {
+ require(getInputCols.length == getNumBucketsArray.length,
+ s"QuantileDiscretizer $this has mismatched Params " +
+ s"for multi-column transform. Params (inputCols, outputCols, numBucketsArray) " +
+ s"should have equal lengths, but they have different lengths: " +
+ s"(${getInputCols.length}, ${getOutputCols.length}, ${getNumBucketsArray.length}).")
+ require(!isSet(numBuckets),
+ s"exactly one of numBuckets, numBucketsArray Params to be set, but both are set." )
+ }
+ }
+
+ val (inputColNames, outputColNames) = if (isSet(inputCols)) {
+ ($(inputCols).toSeq, $(outputCols).toSeq)
+ } else {
+ (Seq($(inputCol)), Seq($(outputCol)))
+ }
val existingFields = schema.fields
var outputFields = existingFields
inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
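
A minimal multi-column sketch matching the validation rules above (inputCols, outputCols, and numBucketsArray must have equal lengths, and numBuckets must not be set at the same time), assuming a SparkSession named `spark` and the hypothetical columns below:

    import org.apache.spark.ml.feature.QuantileDiscretizer

    // Hypothetical numeric columns to bucketize independently.
    val df = spark.createDataFrame(Seq(
      (1.0, 10.0), (2.0, 25.0), (3.0, 30.0), (4.0, 45.0), (5.0, 50.0)
    )).toDF("hour", "amount")

    val discretizer = new QuantileDiscretizer()
      .setInputCols(Array("hour", "amount"))
      .setOutputCols(Array("hourBucket", "amountBucket"))
      .setNumBucketsArray(Array(2, 3))   // one bucket count per input column

    val bucketed = discretizer.fit(df).transform(df)
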
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 6c0d5fc70ab4e..df7d17059980b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -392,7 +392,7 @@ class RFormulaModel private[feature](
}
}
- private def checkCanTransform(schema: StructType) {
+ private def checkCanTransform(schema: StructType): Unit = {
val columnNames = schema.map(_.name)
require(!columnNames.contains($(featuresCol)), "Features column already exists.")
require(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index fb7334d41ba44..bf6e8ec8f37b8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -31,7 +31,7 @@ import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
-import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContext}
+import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContext, SparkException}
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
@@ -42,7 +42,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -564,6 +564,13 @@ object ALSModel extends MLReadable[ALSModel] {
* r is greater than 0 and 0 if r is less than or equal to 0. The ratings then act as 'confidence'
* values related to strength of indicated user
* preferences rather than explicit ratings given to items.
+ *
+ * Note: the input rating dataset to the ALS implementation should be deterministic.
+ * Nondeterministic data can cause failures when fitting the ALS model.
+ * For example, an order-sensitive operation like sampling after a repartition makes the
+ * dataset output nondeterministic, e.g. `dataset.repartition(2).sample(false, 0.5, 1618)`.
+ * Checkpointing the sampled dataset or adding a sort before sampling can help make the
+ * dataset deterministic.
*/
@Since("1.3.0")
class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams
@@ -794,7 +801,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
* Given a triangular matrix in the order of fillXtX above, compute the full symmetric square
* matrix that it represents, storing it into destMatrix.
*/
- private def fillAtA(triAtA: Array[Double], lambda: Double) {
+ private def fillAtA(triAtA: Array[Double], lambda: Double): Unit = {
var i = 0
var pos = 0
var a = 0.0
@@ -1666,6 +1673,13 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
}
val merged = srcOut.groupByKey(new ALSPartitioner(dstInBlocks.partitions.length))
+
+ // SPARK-28927: Nondeterministic RDDs cause inconsistent in/out blocks in case of a rerun.
+ // This can cause a runtime error when matching in/out user/item blocks.
+ val isBlockRDDNondeterministic =
+ dstInBlocks.outputDeterministicLevel == DeterministicLevel.INDETERMINATE ||
+ srcOutBlocks.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
+
dstInBlocks.join(merged).mapValues {
case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
@@ -1686,7 +1700,19 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
val encoded = srcEncodedIndices(i)
val blockId = srcEncoder.blockId(encoded)
val localIndex = srcEncoder.localIndex(encoded)
- val srcFactor = sortedSrcFactors(blockId)(localIndex)
+ var srcFactor: Array[Float] = null
+ try {
+ srcFactor = sortedSrcFactors(blockId)(localIndex)
+ } catch {
+ case a: ArrayIndexOutOfBoundsException if isBlockRDDNondeterministic =>
+ val errMsg = "A failure detected when matching In/Out blocks of users/items. " +
+ "Because at least one In/Out block RDD is found to be nondeterministic now, " +
+ "the issue is probably caused by nondeterministic input data. You can try to " +
+ "checkpoint training data to make it deterministic. If you do `repartition` + " +
+ "`sample` or `randomSplit`, you can also try to sort it before `sample` or " +
+ "`randomSplit` to make it deterministic."
+ throw new SparkException(errMsg, a)
+ }
val rating = ratings(i)
if (implicitPrefs) {
// Extension to the original paper to handle rating < 0. confidence is a function
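
A minimal sketch of one way to follow the determinism note added above, assuming a SparkSession named `spark` and a hypothetical ratings path; sorting before `sample` keeps the sampled output stable across recomputation, which avoids the INDETERMINATE block mismatch guarded against in the code:

    import org.apache.spark.ml.recommendation.ALS

    // Hypothetical path, assumed to contain (user, item, rating) columns.
    val ratings = spark.read.parquet("/path/to/ratings")

    val training = ratings
      .repartition(8)
      .sort("user", "item")   // fix the row order before the order-sensitive sample
      .sample(withReplacement = false, fraction = 0.5, seed = 1618)

    val als = new ALS()
      .setUserCol("user")
      .setItemCol("item")
      .setRatingCol("rating")

    val model = als.fit(training)
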
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 106be1b78af47..602b5fac20d3b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -34,9 +34,8 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
/**
@@ -118,12 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances =
- dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
- }
+ val instances = extractInstances(dataset)
val strategy = getOldStrategy(categoricalFeatures)
instr.logPipelineStage(this)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index a226ca49e6deb..4dc0c247ce331 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -1036,31 +1036,33 @@ class GeneralizedLinearRegressionModel private[ml] (
}
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- var predictionColNames = Seq.empty[String]
- var predictionColumns = Seq.empty[Column]
-
val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
+ var outputData = dataset
+ var numColsOutput = 0
- if ($(predictionCol).nonEmpty) {
- val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) }
- predictionColNames :+= $(predictionCol)
- predictionColumns :+= predictUDF(col($(featuresCol)), offset)
+ if (hasLinkPredictionCol) {
+ val predLinkUDF = udf((features: Vector, offset: Double) => predictLink(features, offset))
+ outputData = outputData
+ .withColumn($(linkPredictionCol), predLinkUDF(col($(featuresCol)), offset))
+ numColsOutput += 1
}
- if (hasLinkPredictionCol) {
- val predictLinkUDF =
- udf { (features: Vector, offset: Double) => predictLink(features, offset) }
- predictionColNames :+= $(linkPredictionCol)
- predictionColumns :+= predictLinkUDF(col($(featuresCol)), offset)
+ if ($(predictionCol).nonEmpty) {
+ if (hasLinkPredictionCol) {
+ val predUDF = udf((eta: Double) => familyAndLink.fitted(eta))
+ outputData = outputData.withColumn($(predictionCol), predUDF(col($(linkPredictionCol))))
+ } else {
+ val predUDF = udf((features: Vector, offset: Double) => predict(features, offset))
+ outputData = outputData.withColumn($(predictionCol), predUDF(col($(featuresCol)), offset))
+ }
+ numColsOutput += 1
}
- if (predictionColNames.nonEmpty) {
- dataset.withColumns(predictionColNames, predictionColumns)
- } else {
+ if (numColsOutput == 0) {
this.logWarning(s"$uid: GeneralizedLinearRegressionModel.transform() does nothing" +
" because no output columns were set.")
- dataset.toDF()
}
+ outputData.toDF
}
/**
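
A minimal sketch of the reworked output-column logic above, assuming a SparkSession named `spark` and the hypothetical counts data below; when linkPredictionCol is set, transform() first fills the link prediction (eta) and then derives predictionCol from it via the family's fitted value, instead of recomputing it from the features:

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.GeneralizedLinearRegression

    // Hypothetical count data: (label, features).
    val train = spark.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.0)),
      (2.0, Vectors.dense(1.0, 0.0)),
      (4.0, Vectors.dense(1.0, 1.0))
    )).toDF("label", "features")

    val glr = new GeneralizedLinearRegression()
      .setFamily("poisson")
      .setLink("log")
      .setLinkPredictionCol("linkPrediction")   // eta column

    val model = glr.fit(train)
    // Adds "linkPrediction" first, then "prediction" computed from it.
    val scored = model.transform(train)
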
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index abf75d70ea028..4c600eac26b37 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -43,7 +43,6 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -320,13 +319,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr =>
// Extract the number of features before deciding optimization solver.
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
- val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances: RDD[Instance] = dataset.select(
- col($(labelCol)), w, col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
- }
+ val instances = extractInstances(dataset)
instr.logPipelineStage(this)
instr.logDataset(dataset)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
index c0a1683d3cb6f..314cf422be87e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
@@ -28,7 +28,8 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
* @tparam Learner Concrete Estimator type
* @tparam M Concrete Model type
*/
-private[spark] abstract class Regressor[
+@DeveloperApi
+abstract class Regressor[
FeaturesType,
Learner <: Regressor[FeaturesType, Learner, M],
M <: RegressionModel[FeaturesType, M]]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 8f8a17171f980..6c194902a750b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -90,7 +90,7 @@ private[spark] class DecisionTreeMetadata(
* Set number of splits for a continuous feature.
* For a continuous feature, number of bins is number of splits plus 1.
*/
- def setNumSplits(featureIndex: Int, numSplits: Int) {
+ def setNumSplits(featureIndex: Int, numSplits: Int): Unit = {
require(isContinuous(featureIndex),
s"Only number of bin for a continuous feature can be set.")
numBins(featureIndex) = numSplits + 1
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 8cd4a7ca9493b..58a763257af20 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -205,21 +205,21 @@ private[spark] class OptionalInstrumentation private(
protected override def logName: String = className
- override def logInfo(msg: => String) {
+ override def logInfo(msg: => String): Unit = {
instrumentation match {
case Some(instr) => instr.logInfo(msg)
case None => super.logInfo(msg)
}
}
- override def logWarning(msg: => String) {
+ override def logWarning(msg: => String): Unit = {
instrumentation match {
case Some(instr) => instr.logWarning(msg)
case None => super.logWarning(msg)
}
}
- override def logError(msg: => String) {
+ override def logError(msg: => String): Unit = {
instrumentation match {
case Some(instr) => instr.logError(msg)
case None => super.logError(msg)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 4617073f9decd..bafaafb720ed8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -347,7 +347,6 @@ private[python] class PythonMLLibAPI extends Serializable {
data: JavaRDD[Vector],
k: Int,
maxIterations: Int,
- runs: Int,
initializationMode: String,
seed: java.lang.Long,
initializationSteps: Int,
@@ -1312,7 +1311,7 @@ private[spark] abstract class SerDeBase {
}
}
- private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler)
+ private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit
}
def dumps(obj: AnyRef): Array[Byte] = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index d86aa01c9195a..df888bc3d5d51 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -224,117 +224,11 @@ class LogisticRegressionWithSGD private[mllib] (
.setMiniBatchFraction(miniBatchFraction)
override protected val validators = List(DataValidators.binaryLabelValidator)
- /**
- * Construct a LogisticRegression object with default parameters: {stepSize: 1.0,
- * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}.
- */
- @Since("0.8.0")
- @deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0")
- def this() = this(1.0, 100, 0.01, 1.0)
-
override protected[mllib] def createModel(weights: Vector, intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
}
}
-/**
- * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent.
- *
- * @note Labels used in Logistic Regression should be {0, 1}
- */
-@Since("0.8.0")
-@deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0")
-object LogisticRegressionWithSGD {
- // NOTE(shivaram): We use multiple train methods instead of default arguments to support
- // Java programs.
-
- /**
- * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
- * number of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
- * gradient descent are initialized using the initial weights provided.
- *
- * @param input RDD of (label, array of features) pairs.
- * @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size to be used for each iteration of gradient descent.
- * @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
- * the number of features in the data.
- *
- * @note Labels used in Logistic Regression should be {0, 1}
- */
- @Since("1.0.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- miniBatchFraction: Double,
- initialWeights: Vector): LogisticRegressionModel = {
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction)
- .run(input, initialWeights)
- }
-
- /**
- * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
- * number of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate the gradient.
- *
- * @param input RDD of (label, array of features) pairs.
- * @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size to be used for each iteration of gradient descent.
- * @param miniBatchFraction Fraction of data to be used per iteration.
- *
- * @note Labels used in Logistic Regression should be {0, 1}
- */
- @Since("1.0.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- miniBatchFraction: Double): LogisticRegressionModel = {
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction)
- .run(input)
- }
-
- /**
- * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
- * number of iterations of gradient descent using the specified step size. We use the entire data
- * set to update the gradient in each iteration.
- *
- * @param input RDD of (label, array of features) pairs.
- * @param stepSize Step size to be used for each iteration of Gradient Descent.
- * @param numIterations Number of iterations of gradient descent to run.
- * @return a LogisticRegressionModel which has the weights and offset from training.
- *
- * @note Labels used in Logistic Regression should be {0, 1}
- */
- @Since("1.0.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double): LogisticRegressionModel = {
- train(input, numIterations, stepSize, 1.0)
- }
-
- /**
- * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
- * number of iterations of gradient descent using a step size of 1.0. We use the entire data set
- * to update the gradient in each iteration.
- *
- * @param input RDD of (label, array of features) pairs.
- * @param numIterations Number of iterations of gradient descent to run.
- * @return a LogisticRegressionModel which has the weights and offset from training.
- *
- * @note Labels used in Logistic Regression should be {0, 1}
- */
- @Since("1.0.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int): LogisticRegressionModel = {
- train(input, numIterations, 1.0, 1.0)
- }
-}
-
/**
* Train a classification model for Multinomial/Binary Logistic Regression using
* Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 4bb79bc69eef4..278d61d916735 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -479,58 +479,6 @@ object KMeans {
.run(data)
}
- /**
- * Trains a k-means model using the given set of parameters.
- *
- * @param data Training points as an `RDD` of `Vector` types.
- * @param k Number of clusters to create.
- * @param maxIterations Maximum number of iterations allowed.
- * @param runs This param has no effect since Spark 2.0.0.
- * @param initializationMode The initialization algorithm. This can either be "random" or
- * "k-means||". (default: "k-means||")
- * @param seed Random seed for cluster initialization. Default is to generate seed based
- * on system time.
- */
- @Since("1.3.0")
- @deprecated("Use train method without 'runs'", "2.1.0")
- def train(
- data: RDD[Vector],
- k: Int,
- maxIterations: Int,
- runs: Int,
- initializationMode: String,
- seed: Long): KMeansModel = {
- new KMeans().setK(k)
- .setMaxIterations(maxIterations)
- .setInitializationMode(initializationMode)
- .setSeed(seed)
- .run(data)
- }
-
- /**
- * Trains a k-means model using the given set of parameters.
- *
- * @param data Training points as an `RDD` of `Vector` types.
- * @param k Number of clusters to create.
- * @param maxIterations Maximum number of iterations allowed.
- * @param runs This param has no effect since Spark 2.0.0.
- * @param initializationMode The initialization algorithm. This can either be "random" or
- * "k-means||". (default: "k-means||")
- */
- @Since("0.8.0")
- @deprecated("Use train method without 'runs'", "2.1.0")
- def train(
- data: RDD[Vector],
- k: Int,
- maxIterations: Int,
- runs: Int,
- initializationMode: String): KMeansModel = {
- new KMeans().setK(k)
- .setMaxIterations(maxIterations)
- .setInitializationMode(initializationMode)
- .run(data)
- }
-
/**
* Trains a k-means model using specified parameters and the default values for unspecified.
*/
@@ -544,21 +492,6 @@ object KMeans {
.run(data)
}
- /**
- * Trains a k-means model using specified parameters and the default values for unspecified.
- */
- @Since("0.8.0")
- @deprecated("Use train method without 'runs'", "2.1.0")
- def train(
- data: RDD[Vector],
- k: Int,
- maxIterations: Int,
- runs: Int): KMeansModel = {
- new KMeans().setK(k)
- .setMaxIterations(maxIterations)
- .run(data)
- }
-
private[spark] def validateInitMode(initMode: String): Boolean = {
initMode match {
case KMeans.RANDOM => true
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index ff4ca0ac40fe2..c7d44e8752cd9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -269,7 +269,7 @@ class StreamingKMeans @Since("1.2.0") (
* @param data DStream containing vector data
*/
@Since("1.2.0")
- def trainOn(data: DStream[Vector]) {
+ def trainOn(data: DStream[Vector]): Unit = {
assertInitialized()
data.foreachRDD { (rdd, time) =>
model = model.update(rdd, decayFactor, timeUnit)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index d34a7ca6c9c7f..f4e2040569f48 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -81,7 +81,7 @@ class BinaryClassificationMetrics @Since("3.0.0") (
* Unpersist intermediate RDDs used in the computation.
*/
@Since("1.0.0")
- def unpersist() {
+ def unpersist(): Unit = {
cumulativeCounts.unpersist()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 82f5b279846ba..b771e077b02ac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -44,17 +44,6 @@ class ChiSqSelectorModel @Since("1.3.0") (
private val filterIndices = selectedFeatures.sorted
- @deprecated("not intended for subclasses to use", "2.1.0")
- protected def isSorted(array: Array[Int]): Boolean = {
- var i = 1
- val len = array.length
- while (i < len) {
- if (array(i) < array(i-1)) return false
- i += 1
- }
- true
- }
-
/**
* Applies transformation on a vector.
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index cb97742245689..1f5558dc2a50e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -285,7 +285,7 @@ private[spark] object BLAS extends Serializable with Logging {
* @param x the vector x that contains the n elements.
* @param A the symmetric matrix A. Size of n x n.
*/
- def syr(alpha: Double, x: Vector, A: DenseMatrix) {
+ def syr(alpha: Double, x: Vector, A: DenseMatrix): Unit = {
val mA = A.numRows
val nA = A.numCols
require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA")
@@ -299,7 +299,7 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
- private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+ private def syr(alpha: Double, x: DenseVector, A: DenseMatrix): Unit = {
val nA = A.numRows
val mA = A.numCols
@@ -317,7 +317,7 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
- private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
+ private def syr(alpha: Double, x: SparseVector, A: DenseMatrix): Unit = {
val mA = A.numCols
val xIndices = x.indices
val xValues = x.values
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index e474cfa002fad..0304fd88dcd9f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -155,7 +155,7 @@ sealed trait Matrix extends Serializable {
* and column indices respectively with the type `Int`, and the final parameter is the
* corresponding value in the matrix with type `Double`.
*/
- private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
+ private[spark] def foreachActive(f: (Int, Int, Double) => Unit): Unit
/**
* Find the number of non-zero active values.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index b754fad0c1796..83a519326df75 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -204,6 +204,14 @@ sealed trait Vector extends Serializable {
*/
@Since("2.0.0")
def asML: newlinalg.Vector
+
+ /**
+ * Calculate the dot product of this vector with another.
+ *
+ * If `size` does not match, an [[IllegalArgumentException]] is thrown.
+ */
+ @Since("3.0.0")
+ def dot(v: Vector): Double = BLAS.dot(this, v)
}
/**
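
A minimal sketch of the new `dot` method added above; it simply delegates to `BLAS.dot`, so the two vectors must have the same size:

    import org.apache.spark.mllib.linalg.Vectors

    val a = Vectors.dense(1.0, 2.0, 3.0)
    val b = Vectors.sparse(3, Array(0, 2), Array(4.0, 5.0))

    // 1.0 * 4.0 + 2.0 * 0.0 + 3.0 * 5.0 = 19.0
    val d = a.dot(b)
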
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
index 0d223de9b6f7e..f3b984948e483 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
@@ -153,7 +153,7 @@ class CoordinateMatrix @Since("1.0.0") (
}
/** Determines the size by computing the max row/column index. */
- private def computeSize() {
+ private def computeSize(): Unit = {
// Reduce will throw an exception if `entries` is empty.
val (m1, n1) = entries.map(entry => (entry.i, entry.j)).reduce { case ((i1, j1), (i2, j2)) =>
(math.max(i1, i2), math.max(j1, j2))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 43f48befd014f..f25d86b30631a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -770,7 +770,7 @@ class RowMatrix @Since("1.0.0") (
}
/** Updates or verifies the number of rows. */
- private def updateNumRows(m: Long) {
+ private def updateNumRows(m: Long): Unit = {
if (nRows <= 0) {
nRows = m
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
index fa04f8eb5e796..d3b548832bb21 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
@@ -107,7 +107,7 @@ class PoissonGenerator @Since("1.1.0") (
override def nextValue(): Double = rng.sample()
@Since("1.1.0")
- override def setSeed(seed: Long) {
+ override def setSeed(seed: Long): Unit = {
rng.reseedRandomGenerator(seed)
}
@@ -132,7 +132,7 @@ class ExponentialGenerator @Since("1.3.0") (
override def nextValue(): Double = rng.sample()
@Since("1.3.0")
- override def setSeed(seed: Long) {
+ override def setSeed(seed: Long): Unit = {
rng.reseedRandomGenerator(seed)
}
@@ -159,7 +159,7 @@ class GammaGenerator @Since("1.3.0") (
override def nextValue(): Double = rng.sample()
@Since("1.3.0")
- override def setSeed(seed: Long) {
+ override def setSeed(seed: Long): Unit = {
rng.reseedRandomGenerator(seed)
}
@@ -187,7 +187,7 @@ class LogNormalGenerator @Since("1.3.0") (
override def nextValue(): Double = rng.sample()
@Since("1.3.0")
- override def setSeed(seed: Long) {
+ override def setSeed(seed: Long): Unit = {
rng.reseedRandomGenerator(seed)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 12870f819b147..f3f15ba0d0f2c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -62,6 +62,13 @@ case class Rating @Since("0.8.0") (
* r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of
* indicated user
* preferences rather than explicit ratings given to items.
+ *
+ * Note: the input rating RDD to the ALS implementation should be deterministic.
+ * Nondeterministic data can cause failures when fitting the ALS model.
+ * For example, an order-sensitive operation like sampling after a repartition makes the
+ * RDD output nondeterministic, e.g. `rdd.repartition(2).sample(false, 0.5, 1618)`.
+ * Checkpointing the sampled RDD or adding a sort before sampling can help make the RDD
+ * deterministic.
*/
@Since("0.8.0")
class ALS private (
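
A minimal RDD-based sketch of the same determinism advice for the mllib ALS, assuming an existing SparkContext `sc` and a hypothetical input path; sorting before `sample` keeps the sampled RDD deterministic across recomputation:

    import org.apache.spark.mllib.recommendation.{ALS, Rating}

    // Hypothetical "user,product,rating" lines.
    val ratings = sc.textFile("/path/to/ratings.csv")
      .map(_.split(","))
      .map(f => Rating(f(0).toInt, f(1).toInt, f(2).toDouble))

    val training = ratings
      .repartition(4)
      .sortBy(r => (r.user, r.product))   // fix the order before sampling
      .sample(withReplacement = false, fraction = 0.5, seed = 1618)

    val model = ALS.train(training, 10, 10)   // rank = 10, iterations = 10
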
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index ead9f5b300375..47bb1fa9127a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Loader, Saveable}
-import org.apache.spark.rdd.RDD
/**
* Regression model trained using Lasso.
@@ -99,117 +98,7 @@ class LassoWithSGD private[mllib] (
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
- /**
- * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100,
- * regParam: 0.01, miniBatchFraction: 1.0}.
- */
- @Since("0.8.0")
- @deprecated("Use ml.regression.LinearRegression with elasticNetParam = 1.0. Note the default " +
- "regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", "2.0.0")
- def this() = this(1.0, 100, 0.01, 1.0)
-
override protected def createModel(weights: Vector, intercept: Double) = {
new LassoModel(weights, intercept)
}
}
-
-/**
- * Top-level methods for calling Lasso.
- *
- */
-@Since("0.8.0")
-@deprecated("Use ml.regression.LinearRegression with elasticNetParam = 1.0. Note the default " +
- "regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", "2.0.0")
-object LassoWithSGD {
-
- /**
- * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
- * in gradient descent are initialized using the initial weights provided.
- *
- * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
- * matrix A as well as the corresponding right hand side label y
- * @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size scaling to be used for the iterations of gradient descent.
- * @param regParam Regularization parameter.
- * @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
- * the number of features in the data.
- *
- */
- @Since("1.0.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- regParam: Double,
- miniBatchFraction: Double,
- initialWeights: Vector): LassoModel = {
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction)
- .run(input, initialWeights)
- }
-
- /**
- * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
- *
- * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
- * matrix A as well as the corresponding right hand side label y
- * @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size to be used for each iteration of gradient descent.
- * @param regParam Regularization parameter.
- * @param miniBatchFraction Fraction of data to be used per iteration.
- *
- */
- @Since("0.8.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- regParam: Double,
- miniBatchFraction: Double): LassoModel = {
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
- }
-
- /**
- * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using the specified step size. We use the entire data set to
- * update the true gradient in each iteration.
- *
- * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
- * matrix A as well as the corresponding right hand side label y
- * @param stepSize Step size to be used for each iteration of Gradient Descent.
- * @param regParam Regularization parameter.
- * @param numIterations Number of iterations of gradient descent to run.
- * @return a LassoModel which has the weights and offset from training.
- *
- */
- @Since("0.8.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- regParam: Double): LassoModel = {
- train(input, numIterations, stepSize, regParam, 1.0)
- }
-
- /**
- * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using a step size of 1.0. We use the entire data set to
- * compute the true gradient in each iteration.
- *
- * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
- * matrix A as well as the corresponding right hand side label y
- * @param numIterations Number of iterations of gradient descent to run.
- * @return a LassoModel which has the weights and offset from training.
- *
- */
- @Since("0.8.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int): LassoModel = {
- train(input, numIterations, 1.0, 0.01, 1.0)
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index cb08216fbf690..f68ebc17e294d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Loader, Saveable}
-import org.apache.spark.rdd.RDD
/**
* Regression model trained using LinearRegression.
@@ -100,109 +99,8 @@ class LinearRegressionWithSGD private[mllib] (
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
- /**
- * Construct a LinearRegression object with default parameters: {stepSize: 1.0,
- * numIterations: 100, miniBatchFraction: 1.0}.
- */
- @Since("0.8.0")
- @deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0")
- def this() = this(1.0, 100, 0.0, 1.0)
-
override protected[mllib] def createModel(weights: Vector, intercept: Double) = {
new LinearRegressionModel(weights, intercept)
}
}
-/**
- * Top-level methods for calling LinearRegression.
- *
- */
-@Since("0.8.0")
-@deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0")
-object LinearRegressionWithSGD {
-
- /**
- * Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
- * in gradient descent are initialized using the initial weights provided.
- *
- * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
- * matrix A as well as the corresponding right hand side label y
- * @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size to be used for each iteration of gradient descent.
- * @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
- * the number of features in the data.
- *
- */
- @Since("1.0.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- miniBatchFraction: Double,
- initialWeights: Vector): LinearRegressionModel = {
- new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction)
- .run(input, initialWeights)
- }
-
- /**
- * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
- *
- * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
- * matrix A as well as the corresponding right hand side label y
- * @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size to be used for each iteration of gradient descent.
- * @param miniBatchFraction Fraction of data to be used per iteration.
- *
- */
- @Since("0.8.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- miniBatchFraction: Double): LinearRegressionModel = {
- new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(input)
- }
-
- /**
- * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using the specified step size. We use the entire data set to
- * compute the true gradient in each iteration.
- *
- * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
- * matrix A as well as the corresponding right hand side label y
- * @param stepSize Step size to be used for each iteration of Gradient Descent.
- * @param numIterations Number of iterations of gradient descent to run.
- * @return a LinearRegressionModel which has the weights and offset from training.
- *
- */
- @Since("0.8.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double): LinearRegressionModel = {
- train(input, numIterations, stepSize, 1.0)
- }
-
- /**
- * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using a step size of 1.0. We use the entire data set to
- * compute the true gradient in each iteration.
- *
- * @param input RDD of (label, array of features) pairs. Each pair describes a row of the data
- * matrix A as well as the corresponding right hand side label y
- * @param numIterations Number of iterations of gradient descent to run.
- * @return a LinearRegressionModel which has the weights and offset from training.
- *
- */
- @Since("0.8.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int): LinearRegressionModel = {
- train(input, numIterations, 1.0, 1.0)
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 43c3154dd053b..1c3bdceab1d14 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -24,8 +24,6 @@ import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression.impl.GLMRegressionModel
import org.apache.spark.mllib.util.{Loader, Saveable}
-import org.apache.spark.rdd.RDD
-
/**
* Regression model trained using RidgeRegression.
@@ -100,113 +98,7 @@ class RidgeRegressionWithSGD private[mllib] (
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
- /**
- * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100,
- * regParam: 0.01, miniBatchFraction: 1.0}.
- */
- @Since("0.8.0")
- @deprecated("Use ml.regression.LinearRegression with elasticNetParam = 0.0. Note the default " +
- "regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for LinearRegression.", "2.0.0")
- def this() = this(1.0, 100, 0.01, 1.0)
-
override protected def createModel(weights: Vector, intercept: Double) = {
new RidgeRegressionModel(weights, intercept)
}
}
-
-/**
- * Top-level methods for calling RidgeRegression.
- *
- */
-@Since("0.8.0")
-@deprecated("Use ml.regression.LinearRegression with elasticNetParam = 0.0. Note the default " +
- "regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for LinearRegression.", "2.0.0")
-object RidgeRegressionWithSGD {
-
- /**
- * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate a stochastic gradient. The weights used
- * in gradient descent are initialized using the initial weights provided.
- *
- * @param input RDD of (label, array of features) pairs.
- * @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size to be used for each iteration of gradient descent.
- * @param regParam Regularization parameter.
- * @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
- * the number of features in the data.
- *
- */
- @Since("1.0.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- regParam: Double,
- miniBatchFraction: Double,
- initialWeights: Vector): RidgeRegressionModel = {
- new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(
- input, initialWeights)
- }
-
- /**
- * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using the specified step size. Each iteration uses
- * `miniBatchFraction` fraction of the data to calculate a stochastic gradient.
- *
- * @param input RDD of (label, array of features) pairs.
- * @param numIterations Number of iterations of gradient descent to run.
- * @param stepSize Step size to be used for each iteration of gradient descent.
- * @param regParam Regularization parameter.
- * @param miniBatchFraction Fraction of data to be used per iteration.
- *
- */
- @Since("0.8.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- regParam: Double,
- miniBatchFraction: Double): RidgeRegressionModel = {
- new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
- }
-
- /**
- * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using the specified step size. We use the entire data set to
- * compute the true gradient in each iteration.
- *
- * @param input RDD of (label, array of features) pairs.
- * @param stepSize Step size to be used for each iteration of Gradient Descent.
- * @param regParam Regularization parameter.
- * @param numIterations Number of iterations of gradient descent to run.
- * @return a RidgeRegressionModel which has the weights and offset from training.
- *
- */
- @Since("0.8.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int,
- stepSize: Double,
- regParam: Double): RidgeRegressionModel = {
- train(input, numIterations, stepSize, regParam, 1.0)
- }
-
- /**
- * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
- * of iterations of gradient descent using a step size of 1.0. We use the entire data set to
- * compute the true gradient in each iteration.
- *
- * @param input RDD of (label, array of features) pairs.
- * @param numIterations Number of iterations of gradient descent to run.
- * @return a RidgeRegressionModel which has the weights and offset from training.
- *
- */
- @Since("0.8.0")
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int): RidgeRegressionModel = {
- train(input, numIterations, 1.0, 0.01, 1.0)
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
index 7f84be9f37822..b6eb10e9de00a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala
@@ -65,7 +65,7 @@ object KMeansDataGenerator {
}
@Since("0.8.0")
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 6) {
// scalastyle:off println
println("Usage: KMeansGenerator " +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
index 58fd010e4905f..c218681b3375e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
@@ -189,7 +189,7 @@ object LinearDataGenerator {
}
@Since("0.8.0")
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 2) {
// scalastyle:off println
println("Usage: LinearDataGenerator " +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
index 68835bc79677f..7e9d9465441c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
@@ -65,7 +65,7 @@ object LogisticRegressionDataGenerator {
}
@Since("0.8.0")
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length != 5) {
// scalastyle:off println
println("Usage: LogisticRegressionGenerator " +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
index 42c5bcdd39f76..7a308a5ec25c0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
@@ -54,7 +54,7 @@ import org.apache.spark.rdd.RDD
@Since("0.8.0")
object MFDataGenerator {
@Since("0.8.0")
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 2) {
// scalastyle:off println
println("Usage: MFDataGenerator " +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 6d15a6bb01e4e..9198334ba02a1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -173,7 +173,7 @@ object MLUtils extends Logging {
* @see `org.apache.spark.mllib.util.MLUtils.loadLibSVMFile`
*/
@Since("1.0.0")
- def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String) {
+ def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String): Unit = {
// TODO: allow to specify label precision and feature precision.
val dataStr = data.map { case LabeledPoint(label, features) =>
val sb = new StringBuilder(label.toString)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
index c9468606544db..9f6ba025aedde 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
@@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD
object SVMDataGenerator {
@Since("0.8.0")
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
if (args.length < 2) {
// scalastyle:off println
println("Usage: SVMGenerator " +
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index b7956b6fd3e9a..69952f0b64ac2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -20,7 +20,7 @@
import java.util.Arrays;
import java.util.List;
-import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
+import org.jtransforms.dct.DoubleDCT_1D;
import org.junit.Assert;
import org.junit.Test;
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index c04e2e69541ba..208a5aaa2bb15 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -50,11 +50,8 @@ public void runLRUsingConstructor() {
List validationData =
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
- LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
+ LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(1.0, 100, 1.0, 1.0);
lrImpl.setIntercept(true);
- lrImpl.optimizer().setStepSize(1.0)
- .setRegParam(1.0)
- .setNumIterations(100);
LogisticRegressionModel model = lrImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -72,8 +69,8 @@ public void runLRUsingStaticMethods() {
List validationData =
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
- LogisticRegressionModel model = LogisticRegressionWithSGD.train(
- testRDD.rdd(), 100, 1.0, 1.0);
+ LogisticRegressionModel model = new LogisticRegressionWithSGD(1.0, 100, 0.01, 1.0)
+ .run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index 270e636f82117..a9a8b7f2b88d6 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -42,11 +42,11 @@ public void runKMeansUsingStaticMethods() {
Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
JavaRDD data = jsc.parallelize(points, 2);
- KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());
+ KMeansModel model = KMeans.train(data.rdd(), 1, 1, KMeans.K_MEANS_PARALLEL());
assertEquals(1, model.clusterCenters().length);
assertEquals(expectedCenter, model.clusterCenters()[0]);
- model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM());
+ model = KMeans.train(data.rdd(), 1, 1, KMeans.RANDOM());
assertEquals(expectedCenter, model.clusterCenters()[0]);
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
index 1458cc72bc17f..35ad24bc2a84f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
@@ -51,10 +51,7 @@ public void runLassoUsingConstructor() {
List validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
- LassoWithSGD lassoSGDImpl = new LassoWithSGD();
- lassoSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.01)
- .setNumIterations(20);
+ LassoWithSGD lassoSGDImpl = new LassoWithSGD(1.0, 20, 0.01, 1.0);
LassoModel model = lassoSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
@@ -72,7 +69,7 @@ public void runLassoUsingStaticMethods() {
List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
- LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
+ LassoModel model = new LassoWithSGD(1.0, 100, 0.01, 1.0).run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 86c723aa00746..7e87588c4f0f6 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -33,7 +33,7 @@ private static int validatePrediction(
List<LabeledPoint> validationData, LinearRegressionModel model) {
int numAccurate = 0;
for (LabeledPoint point : validationData) {
- Double prediction = model.predict(point.features());
+ double prediction = model.predict(point.features());
// A prediction is off if the prediction is more than 0.5 away from expected value.
if (Math.abs(prediction - point.label()) <= 0.5) {
numAccurate++;
@@ -53,7 +53,7 @@ public void runLinearRegressionUsingConstructor() {
List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
- LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+ LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0);
linSGDImpl.setIntercept(true);
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
@@ -72,7 +72,8 @@ public void runLinearRegressionUsingStaticMethods() {
List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
- LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
+ LinearRegressionModel model = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0)
+ .run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
@@ -85,7 +86,7 @@ public void testPredictJavaRDD() {
double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
- LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+ LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0);
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
JavaRDD<Vector> vectors = testRDD.map(LabeledPoint::features);
JavaRDD<Double> predictions = model.predict(vectors);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
index 5a9389c424b44..63441950cd18f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -34,7 +34,7 @@ private static double predictionError(List<LabeledPoint> validationData,
RidgeRegressionModel model) {
double errorSum = 0;
for (LabeledPoint point : validationData) {
- Double prediction = model.predict(point.features());
+ double prediction = model.predict(point.features());
errorSum += (prediction - point.label()) * (prediction - point.label());
}
return errorSum / validationData.size();
@@ -60,11 +60,7 @@ public void runRidgeRegressionUsingConstructor() {
new ArrayList<>(data.subList(0, numExamples)));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
- RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
- ridgeSGDImpl.optimizer()
- .setStepSize(1.0)
- .setRegParam(0.0)
- .setNumIterations(200);
+ RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(1.0, 200, 0.0, 1.0);
RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
double unRegularizedErr = predictionError(validationData, model);
@@ -85,10 +81,12 @@ public void runRidgeRegressionUsingStaticMethods() {
new ArrayList<>(data.subList(0, numExamples)));
List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
- RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
+ RidgeRegressionModel model = new RidgeRegressionWithSGD(1.0, 200, 0.0, 1.0)
+ .run(testRDD.rdd());
double unRegularizedErr = predictionError(validationData, model);
- model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1);
+ model = new RidgeRegressionWithSGD(1.0, 200, 0.1, 1.0)
+ .run(testRDD.rdd());
double regularizedErr = predictionError(validationData, model);
Assert.assertTrue(regularizedErr < unRegularizedErr);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
index e2ee7c05ab399..f2343b7a88560 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
@@ -25,7 +25,7 @@ import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.BeforeAndAfterEach
import org.scalatest.concurrent.Eventually
-import org.scalatest.mockito.MockitoSugar.mock
+import org.scalatestplus.mockito.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamMap
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 1183cb0617610..e6025a5a53ca6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.when
-import org.scalatest.mockito.MockitoSugar.mock
+import org.scalatestplus.mockito.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 9f2053dcc91fc..3ebf8a83a892c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -44,7 +44,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
private val seed = 42
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
categoricalDataPointsRDD =
sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()).map(_.asML)
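This `beforeAll()` change, repeated across the suites below, replaces Scala's procedure syntax (a method body with no `=`), which is deprecated in newer Scala versions, with an explicit `Unit` result type. A minimal sketch of the pattern, with `SomeBaseSuite` standing in for whichever test trait is being extended:

    trait SomeBaseSuite { def beforeAll(): Unit = () }

    class ExampleSuite extends SomeBaseSuite {
      // old (procedure syntax): override def beforeAll() { ... }
      override def beforeAll(): Unit = {
        super.beforeAll()
        // per-suite setup goes here
      }
    }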
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 467f13f808a01..af3dd201d3b51 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -55,7 +55,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
private val eps: Double = 1e-5
private val absEps: Double = 1e-8
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2)
.map(_.asML)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 0f0954e5d8cac..f03ed0b76eb80 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -42,7 +42,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
orderedLabeledPoints50_1000 =
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000))
@@ -56,7 +56,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
- def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier) {
+ def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier): Unit = {
val categoricalFeatures = Map.empty[Int, Int]
val numClasses = 2
val newRF = rf
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index c1a156959618e..f4f858c3e92dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -76,6 +76,10 @@ class RegressionEvaluatorSuite
// mae
evaluator.setMetricName("mae")
assert(evaluator.evaluate(predictions) ~== 0.08399089 absTol 0.01)
+
+ // var
+ evaluator.setMetricName("var")
+ assert(evaluator.evaluate(predictions) ~== 63.6944519 absTol 0.01)
}
test("read/write") {
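The new assertion exercises "var" (explained variance) as a metric name alongside the existing "rmse", "mse", "r2", and "mae". A hedged usage sketch, assuming a DataFrame `predictions` with "prediction" and "label" columns:

    import org.apache.spark.ml.evaluation.RegressionEvaluator

    val evaluator = new RegressionEvaluator()
      .setPredictionCol("prediction")
      .setLabelCol("label")
      .setMetricName("var")
    val explainedVariance = evaluator.evaluate(predictions)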
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 05d4a6ee2dabf..91bec50fb904f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -101,6 +101,20 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("Binarizer should support sparse vector with negative threshold") {
+ val data = Seq(
+ (Vectors.sparse(3, Array(1), Array(0.5)), Vectors.dense(Array(1.0, 1.0, 1.0))),
+ (Vectors.dense(Array(0.0, 0.5, 0.0)), Vectors.dense(Array(1.0, 1.0, 1.0))))
+ val df = data.toDF("feature", "expected")
+ val binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+ .setThreshold(-0.5)
+ binarizer.transform(df).select("binarized_feature", "expected").collect().foreach {
+ case Row(x: Vector, y: Vector) =>
+ assert(x == y, "The feature value is not correct after binarization.")
+ }
+ }
test("read/write") {
val t = new Binarizer()
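The new test pins down a subtle point: with a negative threshold, the implicit zeros of a sparse vector also exceed the threshold, so every position binarizes to 1.0 and the output is effectively dense. A minimal sketch, assuming a SparkSession `spark` is in scope:

    import org.apache.spark.ml.feature.Binarizer
    import org.apache.spark.ml.linalg.Vectors
    import spark.implicits._

    val df = Seq(Tuple1(Vectors.sparse(3, Array(1), Array(0.5)))).toDF("feature")
    val binarized = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
      .setThreshold(-0.5)
      .transform(df)
    binarized.show(false)   // binarized_feature = [1.0, 1.0, 1.0]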
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
index 985e396000d05..079dabb3665be 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
+import org.jtransforms.dct.DoubleDCT_1D
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index ae086d32d6d0b..6f6ab26cbac43 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql._
@@ -423,33 +424,92 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
assert(readDiscretizer.hasDefault(readDiscretizer.outputCol))
}
- test("Multiple Columns: Both inputCol and inputCols are set") {
+ test("Multiple Columns: Mismatched sizes of inputCols/outputCols") {
val spark = this.spark
import spark.implicits._
val discretizer = new QuantileDiscretizer()
- .setInputCol("input")
- .setOutputCol("result")
+ .setInputCols(Array("input"))
+ .setOutputCols(Array("result1", "result2"))
.setNumBuckets(3)
- .setInputCols(Array("input1", "input2"))
val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
.map(Tuple1.apply).toDF("input")
- // When both inputCol and inputCols are set, we throw Exception.
intercept[IllegalArgumentException] {
discretizer.fit(df)
}
}
- test("Multiple Columns: Mismatched sizes of inputCols / outputCols") {
+ test("Multiple Columns: Mismatched sizes of inputCols/numBucketsArray") {
val spark = this.spark
import spark.implicits._
val discretizer = new QuantileDiscretizer()
- .setInputCols(Array("input"))
+ .setInputCols(Array("input1", "input2"))
.setOutputCols(Array("result1", "result2"))
- .setNumBuckets(3)
+ .setNumBucketsArray(Array(2, 5, 10))
+ val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0)
+ val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0)
+ val df = data1.zip(data2).toSeq.toDF("input1", "input2")
+ intercept[IllegalArgumentException] {
+ discretizer.fit(df)
+ }
+ }
+
+ test("Multiple Columns: Set both of numBuckets/numBucketsArray") {
+ val spark = this.spark
+ import spark.implicits._
+ val discretizer = new QuantileDiscretizer()
+ .setInputCols(Array("input1", "input2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setNumBucketsArray(Array(2, 5))
+ .setNumBuckets(2)
+ val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0)
+ val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0)
+ val df = data1.zip(data2).toSeq.toDF("input1", "input2")
+ intercept[IllegalArgumentException] {
+ discretizer.fit(df)
+ }
+ }
+
+ test("Setting numBucketsArray for Single-Column QuantileDiscretizer") {
+ val spark = this.spark
+ import spark.implicits._
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBucketsArray(Array(2, 5))
val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
.map(Tuple1.apply).toDF("input")
intercept[IllegalArgumentException] {
discretizer.fit(df)
}
}
+
+ test("Assert exception is thrown if both multi-column and single-column params are set") {
+ val spark = this.spark
+ import spark.implicits._
+ val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
+ ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("inputCol", "feature1"),
+ ("inputCols", Array("feature1", "feature2")))
+ ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("inputCol", "feature1"),
+ ("outputCol", "result1"), ("outputCols", Array("result1", "result2")))
+ // this should fail because at least one of inputCol and inputCols must be set
+ ParamsSuite.testExclusiveParams(new QuantileDiscretizer, df, ("outputCol", "feature1"))
+ }
+
+ test("Setting inputCol without setting outputCol") {
+ val spark = this.spark
+ import spark.implicits._
+
+ val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
+ .map(Tuple1.apply).toDF("input")
+ val numBuckets = 2
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setNumBuckets(numBuckets)
+ val model = discretizer.fit(df)
+ val result = model.transform(df)
+
+ val observedNumBuckets = result.select(discretizer.getOutputCol).distinct.count
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of buckets.")
+ }
}
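Taken together, these tests document the multi-column contract: inputCols, outputCols and numBucketsArray must have matching lengths, numBuckets and numBucketsArray are mutually exclusive, and the single-column params (inputCol/outputCol) cannot be mixed with their multi-column counterparts. A short sketch of a valid multi-column configuration, assuming a DataFrame `df` with numeric columns "input1" and "input2":

    import org.apache.spark.ml.feature.QuantileDiscretizer

    val discretizer = new QuantileDiscretizer()
      .setInputCols(Array("input1", "input2"))
      .setOutputCols(Array("result1", "result2"))
      .setNumBucketsArray(Array(2, 5))   // one bucket count per input column
    val bucketed = discretizer.fit(df).transform(df)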
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
index add1cc17ea057..efd56f7073a19 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -25,7 +25,7 @@ class RFormulaParserSuite extends SparkFunSuite {
formula: String,
label: String,
terms: Seq[String],
- schema: StructType = new StructType) {
+ schema: StructType = new StructType): Unit = {
val resolved = RFormulaParser.parse(formula).resolve(schema)
assert(resolved.label == label)
val simpleTerms = terms.map { t =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 630e785e59507..49ebcb385640e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -40,7 +40,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
private val seed = 42
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
categoricalDataPointsRDD =
sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 884fe2d11bf5a..60007975c3b52 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -47,7 +47,7 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
private var trainData: RDD[LabeledPoint] = _
private var validationData: RDD[LabeledPoint] = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2)
.map(_.asML)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c6dabd1b28829..0243e8d2335ee 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -38,7 +38,7 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
orderedLabeledPoints50_1000 =
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
@@ -49,7 +49,7 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
- def regressionTestWithContinuousFeatures(rf: RandomForestRegressor) {
+ def regressionTestWithContinuousFeatures(rf: RandomForestRegressor): Unit = {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val newRF = rf
.setImpurity("variance")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index a63ab913f2c22..ae5e979983b4f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -485,7 +485,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
- def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
+ def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(
+ strategy: OldStrategy): Unit = {
val numFeatures = 50
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
val rdd = sc.parallelize(arr).map(_.asML.toInstance)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 8a0a48ff6095b..90079c9848823 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -56,7 +56,7 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
sc.setCheckpointDir(checkpointDir)
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
Utils.deleteRecursively(new File(checkpointDir))
} finally {
@@ -127,7 +127,7 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
dataframe: DataFrame,
transformer: Transformer,
expectedMessagePart : String,
- firstResultCol: String) {
+ firstResultCol: String): Unit = {
withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") {
val exceptionOnDf = intercept[Throwable] {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 5cf4377768516..d4e9da3c6263e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -206,7 +206,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
def validatePrediction(
predictions: Seq[Double],
input: Seq[LabeledPoint],
- expectedAcc: Double = 0.83) {
+ expectedAcc: Double = 0.83): Unit = {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
prediction != expected.label
}
@@ -224,12 +224,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val lr = new LogisticRegressionWithSGD().setIntercept(true)
- lr.optimizer
- .setStepSize(10.0)
- .setRegParam(0.0)
- .setNumIterations(20)
- .setConvergenceTol(0.0005)
+ val lr = new LogisticRegressionWithSGD(10.0, 20, 0.0, 1.0).setIntercept(true)
+ lr.optimizer.setConvergenceTol(0.0005)
val model = lr.run(testRDD)
@@ -300,11 +296,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
testRDD.cache()
// Use half as many iterations as the previous test.
- val lr = new LogisticRegressionWithSGD().setIntercept(true)
- lr.optimizer
- .setStepSize(10.0)
- .setRegParam(0.0)
- .setNumIterations(10)
+ val lr = new LogisticRegressionWithSGD(10.0, 10, 0.0, 1.0).setIntercept(true)
val model = lr.run(testRDD, initialWeights)
@@ -335,11 +327,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
testRDD.cache()
// Use half as many iterations as the previous test.
- val lr = new LogisticRegressionWithSGD().setIntercept(true)
- lr.optimizer.
- setStepSize(1.0).
- setNumIterations(10).
- setRegParam(1.0)
+ val lr = new LogisticRegressionWithSGD(1.0, 10, 1.0, 1.0).setIntercept(true)
val model = lr.run(testRDD, initialWeights)
@@ -916,7 +904,7 @@ class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSpar
}.cache()
// If we serialize data directly in the task closure, the size of the serialized task would be
// greater than 1MB and hence Spark would throw an error.
- val model = LogisticRegressionWithSGD.train(points, 2)
+ val model = new LogisticRegressionWithSGD(1.0, 2, 0.0, 1.0).run(points)
val predictions = model.predict(points.map(_.features))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 725389813b3e2..47dac3ec29a5c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -91,7 +91,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
import NaiveBayes.{Multinomial, Bernoulli}
- def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = {
val numOfPredictions = predictions.zip(input).count {
case (prediction, expected) =>
prediction != expected.label
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 3676d9c5debc8..007b8ae6e1a6a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -62,7 +62,7 @@ object SVMSuite {
class SVMSuite extends SparkFunSuite with MLlibTestSparkContext {
- def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
prediction != expected.label
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index 5f797a60f09e6..7349e0319324a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -23,23 +23,17 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
+import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream
-class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
+class StreamingLogisticRegressionSuite
+ extends SparkFunSuite
+ with LocalStreamingContext
+ with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 30000
- var ssc: StreamingContext = _
-
- override def afterFunction() {
- super.afterFunction()
- if (ssc != null) {
- ssc.stop()
- }
- }
-
// Test if we can accurately learn B for Y = logistic(BX) on streaming data
test("parameter accuracy") {
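Here and in StreamingKMeansSuite and StreamingLinearRegressionSuite below, the hand-rolled `ssc` field and afterFunction() cleanup are replaced by mixing in LocalStreamingContext. The trait itself is not shown in this section; a hedged sketch of the kind of per-test cleanup it centralizes:

    import org.scalatest.{BeforeAndAfterEach, Suite}
    import org.apache.spark.streaming.StreamingContext

    // Illustrative only; the real LocalStreamingContext lives in the streaming test utilities.
    trait LocalStreamingContextSketch extends BeforeAndAfterEach { self: Suite =>
      @transient var ssc: StreamingContext = _

      override def afterEach(): Unit = {
        try {
          if (ssc != null) {
            ssc.stop(stopSparkContext = true)
            ssc = null
          }
        } finally {
          super.afterEach()
        }
      }
    }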
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index c4bf5b27187f6..149a525a58ff6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -367,7 +367,7 @@ class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
for (initMode <- Seq(KMeans.RANDOM, KMeans.K_MEANS_PARALLEL)) {
// If we serialize data directly in the task closure, the size of the serialized task would be
// greater than 1MB and hence Spark would throw an error.
- val model = KMeans.train(points, 2, 2, 1, initMode)
+ val model = KMeans.train(points, 2, 2, initMode)
val predictions = model.predict(points).collect()
val cost = model.computeCost(points)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
index a1ac10c06c697..415ac87275390 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -20,23 +20,14 @@ package org.apache.spark.mllib.clustering
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
+import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.random.XORShiftRandom
-class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
+class StreamingKMeansSuite extends SparkFunSuite with LocalStreamingContext with TestSuiteBase {
override def maxWaitTimeMillis: Int = 30000
- var ssc: StreamingContext = _
-
- override def afterFunction() {
- super.afterFunction()
- if (ssc != null) {
- ssc.stop()
- }
- }
-
test("accuracy for single center and equivalence to grand average") {
// set parameters
val numBatches = 10
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index b4520d42fedf5..184c89c9eaaf9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -24,7 +24,7 @@ import scala.reflect.ClassTag
import breeze.linalg.{CSCMatrix, Matrix => BM}
import org.mockito.Mockito.when
-import org.scalatest.mockito.MockitoSugar._
+import org.scalatestplus.mockito.MockitoSugar._
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.config.Kryo._
@@ -39,7 +39,7 @@ class MatricesSuite extends SparkFunSuite {
val ser = new KryoSerializer(conf).newInstance()
- def check[T: ClassTag](t: T) {
+ def check[T: ClassTag](t: T): Unit = {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index fee0b02bf8ed8..c0c5c5c7d98d5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -42,7 +42,7 @@ class VectorsSuite extends SparkFunSuite with Logging {
conf.set(KRYO_REGISTRATION_REQUIRED, true)
val ser = new KryoSerializer(conf).newInstance()
- def check[T: ClassTag](t: T) {
+ def check[T: ClassTag](t: T): Unit = {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
@@ -510,4 +510,27 @@ class VectorsSuite extends SparkFunSuite with Logging {
Vectors.sparse(-1, Array((1, 2.0)))
}
}
+
+ test("dot product only supports vectors of same size") {
+ val vSize4 = Vectors.dense(arr)
+ val vSize1 = Vectors.zeros(1)
+ intercept[IllegalArgumentException]{ vSize1.dot(vSize4) }
+ }
+
+ test("dense vector dot product") {
+ val dv = Vectors.dense(arr)
+ assert(dv.dot(dv) === 0.26)
+ }
+
+ test("sparse vector dot product") {
+ val sv = Vectors.sparse(n, indices, values)
+ assert(sv.dot(sv) === 0.26)
+ }
+
+ test("mixed sparse and dense vector dot product") {
+ val sv = Vectors.sparse(n, indices, values)
+ val dv = Vectors.dense(arr)
+ assert(sv.dot(dv) === 0.26)
+ assert(dv.dot(sv) === 0.26)
+ }
}
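The expected value 0.26 is consistent with this suite's shared fixture, where `arr` is (0.1, 0.0, 0.3, 0.4) and the sparse vector stores the same non-zeros at indices 0, 2 and 3: 0.1*0.1 + 0.0*0.0 + 0.3*0.3 + 0.4*0.4 = 0.01 + 0.00 + 0.09 + 0.16 = 0.26. A worked check, assuming the `dot` method exercised by these tests is available on mllib vectors:

    import org.apache.spark.mllib.linalg.Vectors

    val dv = Vectors.dense(0.1, 0.0, 0.3, 0.4)
    val sv = Vectors.sparse(4, Array(0, 2, 3), Array(0.1, 0.3, 0.4))
    // allow for floating-point rounding rather than asserting exact equality
    assert(math.abs(dv.dot(sv) - 0.26) < 1e-12)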
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index f6a996940291c..9d7177e0a149e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -35,7 +35,7 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val numPartitions = 3
var gridBasedMat: BlockMatrix = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
val blocks: Seq[((Int, Int), Matrix)] = Seq(
((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
index 37d75103d18d2..d197f06a393e8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
@@ -29,7 +29,7 @@ class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val n = 4
var mat: CoordinateMatrix = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
val entries = sc.parallelize(Seq(
(0, 0, 1.0),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index cca4eb4e4260e..e961d10711860 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -36,7 +36,7 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
).map(x => IndexedRow(x._1, x._2))
var indexedRows: RDD[IndexedRow] = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
indexedRows = sc.parallelize(data, 2)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index a0c4c68243e67..0a4b11935580a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -57,7 +57,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
var denseMat: RowMatrix = _
var sparseMat: RowMatrix = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
denseMat = new RowMatrix(sc.parallelize(denseData, 2))
sparseMat = new RowMatrix(sc.parallelize(sparseData, 2))
@@ -213,7 +213,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
brzNorm(v, 1.0) < 1e-6
}
- def assertColumnEqualUpToSign(A: BDM[Double], B: BDM[Double], k: Int) {
+ def assertColumnEqualUpToSign(A: BDM[Double], B: BDM[Double], k: Int): Unit = {
assert(A.rows === B.rows)
for (j <- 0 until k) {
val aj = A(::, j)
@@ -338,7 +338,7 @@ class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext
var mat: RowMatrix = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
val m = 4
val n = 200000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
index b3bf5a2a8f2cc..a629c6951abcd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.util.StatCounter
class RandomDataGeneratorSuite extends SparkFunSuite {
- def apiChecks(gen: RandomDataGenerator[Double]) {
+ def apiChecks(gen: RandomDataGenerator[Double]): Unit = {
// resetting seed should generate the same sequence of random numbers
gen.setSeed(42L)
val array1 = (0 until 1000).map(_ => gen.nextValue())
@@ -56,7 +56,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
def distributionChecks(gen: RandomDataGenerator[Double],
mean: Double = 0.0,
stddev: Double = 1.0,
- epsilon: Double = 0.01) {
+ epsilon: Double = 0.01): Unit = {
for (seed <- 0 until 5) {
gen.setSeed(seed.toLong)
val sample = (0 until 100000).map { _ => gen.nextValue()}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
index 9b4dc29d326a1..470e1016dab39 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
@@ -38,7 +38,7 @@ class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Seri
expectedNumPartitions: Int,
expectedMean: Double,
expectedStddev: Double,
- epsilon: Double = 0.01) {
+ epsilon: Double = 0.01): Unit = {
val stats = rdd.stats()
assert(expectedSize === stats.count)
assert(expectedNumPartitions === rdd.partitions.size)
@@ -53,7 +53,7 @@ class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Seri
expectedNumPartitions: Int,
expectedMean: Double,
expectedStddev: Double,
- epsilon: Double = 0.01) {
+ epsilon: Double = 0.01): Unit = {
assert(expectedNumPartitions === rdd.partitions.size)
val values = new ArrayBuffer[Double]()
rdd.collect.foreach { vector => {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index b08ad99f4f204..9be87db873dad 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -224,7 +224,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext {
negativeWeights: Boolean = false,
numUserBlocks: Int = -1,
numProductBlocks: Int = -1,
- negativeFactors: Boolean = true) {
+ negativeFactors: Boolean = true): Unit = {
// scalastyle:on
val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index d96103d01e4ab..f336dac0ccb5d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -33,7 +33,7 @@ private object LassoSuite {
class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
- def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
// A prediction is off if the prediction is more than 0.5 away from expected value.
math.abs(prediction - expected.label) > 0.5
@@ -55,8 +55,7 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
}
val testRDD = sc.parallelize(testData, 2).cache()
- val ls = new LassoWithSGD()
- ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
+ val ls = new LassoWithSGD(1.0, 40, 0.01, 1.0)
val model = ls.run(testRDD)
val weight0 = model.weights(0)
@@ -99,8 +98,8 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
val testRDD = sc.parallelize(testData, 2).cache()
- val ls = new LassoWithSGD()
- ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40).setConvergenceTol(0.0005)
+ val ls = new LassoWithSGD(1.0, 40, 0.01, 1.0)
+ ls.optimizer.setConvergenceTol(0.0005)
val model = ls.run(testRDD, initialWeights)
val weight0 = model.weights(0)
@@ -153,7 +152,7 @@ class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
}.cache()
// If we serialize data directly in the task closure, the size of the serialized task would be
// greater than 1MB and hence Spark would throw an error.
- val model = LassoWithSGD.train(points, 2)
+ val model = new LassoWithSGD(1.0, 2, 0.01, 1.0).run(points)
val predictions = model.predict(points.map(_.features))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 0694079b9df9e..be0834d0fd7df 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -33,7 +33,7 @@ private object LinearRegressionSuite {
class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
- def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
// A prediction is off if the prediction is more than 0.5 away from expected value.
math.abs(prediction - expected.label) > 0.5
@@ -46,7 +46,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("linear regression") {
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 100, 42), 2).cache()
- val linReg = new LinearRegressionWithSGD().setIntercept(true)
+ val linReg = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0).setIntercept(true)
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
val model = linReg.run(testRDD)
@@ -72,7 +72,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("linear regression without intercept") {
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 42), 2).cache()
- val linReg = new LinearRegressionWithSGD().setIntercept(false)
+ val linReg = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0).setIntercept(false)
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
val model = linReg.run(testRDD)
@@ -103,7 +103,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1))))
LabeledPoint(label, sv)
}.cache()
- val linReg = new LinearRegressionWithSGD().setIntercept(false)
+ val linReg = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0).setIntercept(false)
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
val model = linReg.run(sparseRDD)
@@ -160,7 +160,7 @@ class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkC
}.cache()
// If we serialize data directly in the task closure, the size of the serialized task would be
// greater than 1MB and hence Spark would throw an error.
- val model = LinearRegressionWithSGD.train(points, 2)
+ val model = new LinearRegressionWithSGD(1.0, 2, 0.0, 1.0).run(points)
val predictions = model.predict(points.map(_.features))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 815be32d2e510..2d6aec184ad9d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -60,18 +60,13 @@ class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationRDD = sc.parallelize(validationData, 2).cache()
// First run without regularization.
- val linearReg = new LinearRegressionWithSGD()
- linearReg.optimizer.setNumIterations(200)
- .setStepSize(1.0)
+ val linearReg = new LinearRegressionWithSGD(1.0, 200, 0.0, 1.0)
val linearModel = linearReg.run(testRDD)
val linearErr = predictionError(
linearModel.predict(validationRDD.map(_.features)).collect(), validationData)
- val ridgeReg = new RidgeRegressionWithSGD()
- ridgeReg.optimizer.setNumIterations(200)
- .setRegParam(0.1)
- .setStepSize(1.0)
+ val ridgeReg = new RidgeRegressionWithSGD(1.0, 200, 0.1, 1.0)
val ridgeModel = ridgeReg.run(testRDD)
val ridgeErr = predictionError(
ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)
@@ -110,7 +105,7 @@ class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkCo
}.cache()
// If we serialize data directly in the task closure, the size of the serialized task would be
// greater than 1MB and hence Spark would throw an error.
- val model = RidgeRegressionWithSGD.train(points, 2)
+ val model = new RidgeRegressionWithSGD(1.0, 2, 0.01, 1.0).run(points)
val predictions = model.predict(points.map(_.features))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index eaeaa3fc1e68d..8e2d7d10f2ce2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -22,31 +22,25 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LinearDataGenerator
-import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
+import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream
-class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
+class StreamingLinearRegressionSuite
+ extends SparkFunSuite
+ with LocalStreamingContext
+ with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 20000
- var ssc: StreamingContext = _
-
- override def afterFunction() {
- super.afterFunction()
- if (ssc != null) {
- ssc.stop()
- }
- }
-
// Assert that two values are equal within tolerance epsilon
- def assertEqual(v1: Double, v2: Double, epsilon: Double) {
+ def assertEqual(v1: Double, v2: Double, epsilon: Double): Unit = {
def errorMessage = v1.toString + " did not equal " + v2.toString
assert(math.abs(v1-v2) <= epsilon, errorMessage)
}
// Assert that model predictions are correct
- def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]): Unit = {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
// A prediction is off if the prediction is more than 0.5 away from expected value.
math.abs(prediction - expected.label) > 0.5
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 88b9d4c039ba9..b738236473230 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -437,7 +437,7 @@ object DecisionTreeSuite extends SparkFunSuite {
def validateClassifier(
model: DecisionTreeModel,
input: Seq[LabeledPoint],
- requiredAccuracy: Double) {
+ requiredAccuracy: Double): Unit = {
val predictions = input.map(x => model.predict(x.features))
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
prediction != expected.label
@@ -450,7 +450,7 @@ object DecisionTreeSuite extends SparkFunSuite {
def validateRegressor(
model: DecisionTreeModel,
input: Seq[LabeledPoint],
- requiredMSE: Double) {
+ requiredMSE: Double): Unit = {
val predictions = input.map(x => model.predict(x.features))
val squaredError = predictions.zip(input).map { case (prediction, expected) =>
val err = prediction - expected.label
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index d43e62bb65535..e04d7b7c327a8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -37,7 +37,7 @@ object EnsembleTestHelper {
numCols: Int,
expectedMean: Double,
expectedStddev: Double,
- epsilon: Double) {
+ epsilon: Double): Unit = {
val values = new mutable.ArrayBuffer[Double]()
data.foreach { row =>
assert(row.size == numCols)
@@ -51,7 +51,7 @@ object EnsembleTestHelper {
def validateClassifier(
model: TreeEnsembleModel,
input: Seq[LabeledPoint],
- requiredAccuracy: Double) {
+ requiredAccuracy: Double): Unit = {
val predictions = input.map(x => model.predict(x.features))
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
prediction != expected.label
@@ -68,7 +68,7 @@ object EnsembleTestHelper {
model: TreeEnsembleModel,
input: Seq[LabeledPoint],
required: Double,
- metricName: String = "mse") {
+ metricName: String = "mse"): Unit = {
val predictions = input.map(x => model.predict(x.features))
val errors = predictions.zip(input).map { case (prediction, point) =>
point.label - prediction
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index bec61ba6a003c..b1a385a576cea 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.util.Utils
* Test suite for [[RandomForest]].
*/
class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
- def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
+ def binaryClassificationTestWithContinuousFeatures(strategy: Strategy): Unit = {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val numTrees = 1
@@ -68,7 +68,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
binaryClassificationTestWithContinuousFeatures(strategy)
}
- def regressionTestWithContinuousFeatures(strategy: Strategy) {
+ def regressionTestWithContinuousFeatures(strategy: Strategy): Unit = {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val numTrees = 1
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
index 2853b752cb85c..79d4785fd6fa7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
@@ -25,7 +25,7 @@ import org.apache.spark.internal.config.Network.RPC_MESSAGE_MAX_SIZE
trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
@transient var sc: SparkContext = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
val conf = new SparkConf()
.setMaster("local-cluster[2, 1, 1024]")
@@ -34,7 +34,7 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
sc = new SparkContext(conf)
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
if (sc != null) {
sc.stop()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index 720237bd2dddd..f9a3cd088314e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -31,7 +31,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
@transient var sc: SparkContext = _
@transient var checkpointDir: String = _
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
spark = SparkSession.builder
.master("local[2]")
@@ -43,7 +43,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
sc.setCheckpointDir(checkpointDir)
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
Utils.deleteRecursively(new File(checkpointDir))
SparkSession.clearActiveSession()
diff --git a/pom.xml b/pom.xml
index 6c474f5f7a3e7..9c2aa9de85ce6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -125,7 +125,7 @@
2.7.4
2.5.0
${hadoop.version}
- 3.4.6
+ 3.4.14
2.7.1
0.4.2
org.spark-project.hive
@@ -139,7 +139,7 @@
2.3.0
10.12.1.1
1.10.1
- 1.5.5
+ 1.5.6
nohive
com.twitter
1.6.0
@@ -164,14 +164,14 @@
3.4.1
3.2.2
- 2.12.8
+ 2.12.10
2.12
--diff --test
true
1.9.13
- 2.9.9
- 2.9.9.3
+ 2.9.10
+ 2.9.10
1.1.7.3
1.1.2
1.10
@@ -240,7 +240,7 @@
-->
${session.executionRootDirectory}
- 512m
+ 1g
@@ -620,7 +620,7 @@
com.github.luben
zstd-jni
- 1.4.2-1
+ 1.4.3-1
com.clearspring.analytics
@@ -786,14 +786,8 @@
org.scalanlp
breeze_${scala.binary.version}
- 0.13.2
+ 1.0
-
-
- junit
- junit
-
org.apache.commons
commons-math3
@@ -839,7 +833,7 @@
org.scala-lang.modules
scala-parser-combinators_${scala.binary.version}
- 1.1.0
+ 1.1.2
jline
@@ -849,7 +843,7 @@
org.scalatest
scalatest_${scala.binary.version}
- 3.0.5
+ 3.0.8
test
@@ -867,7 +861,7 @@
org.scalacheck
scalacheck_${scala.binary.version}
- 1.13.5
+ 1.14.2
test
@@ -1343,6 +1337,10 @@
io.netty
netty
+
+ com.github.spotbugs
+ spotbugs-annotations
+
@@ -2002,75 +2000,6 @@
-
- ${hive.group}
- hive-contrib
- ${hive.version}
- test
-
-
- ${hive.group}
- hive-exec
-
-
- ${hive.group}
- hive-serde
-
-
- ${hive.group}
- hive-shims
-
-
- commons-codec
- commons-codec
-
-
- org.slf4j
- slf4j-api
-
-
-
-
- ${hive.group}.hcatalog
- hive-hcatalog-core
- ${hive.version}
- test
-
-
- ${hive.group}
- hive-exec
-
-
- ${hive.group}
- hive-metastore
-
-
- ${hive.group}
- hive-cli
-
-
- ${hive.group}
- hive-common
-
-
- com.google.guava
- guava
-
-
- org.slf4j
- slf4j-api
-
-
- org.codehaus.jackson
- jackson-mapper-asl
-
-
- org.apache.hadoop
- *
-
-
-
-
org.apache.orc
orc-core
@@ -2287,6 +2216,17 @@
+
+ enforce-no-duplicate-dependencies
+
+ enforce
+
+
+
+
+
+
+
@@ -2974,7 +2914,6 @@
3.2.0
2.13.0
- 3.4.13
org.apache.hive
core
${hive23.version}
@@ -3053,6 +2992,19 @@
scala-2.12
+
+
+ scala-2.13
+
+
+
+ org.scala-lang.modules
+ scala-parallel-collections_${scala.binary.version}
+ 0.2.0
+
+
+
+
-
- ${hive.group}
- hive-contrib
-
-
- ${hive.group}.hcatalog
- hive-hcatalog-core
-
org.eclipse.jetty
jetty-server
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index 36d4ac095e10c..9517a599be633 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -72,7 +72,7 @@ object HiveThriftServer2 extends Logging {
server
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
// If the arguments contains "-h" or "--help", print out the usage and exit.
if (args.contains("-h") || args.contains("--help")) {
HiveServer2.main(args)
@@ -303,7 +303,7 @@ private[hive] class HiveThriftServer2(sqlContext: SQLContext)
// started, and then once only.
private val started = new AtomicBoolean(false)
- override def init(hiveConf: HiveConf) {
+ override def init(hiveConf: HiveConf): Unit = {
val sparkSqlCliService = new SparkSQLCLIService(this, sqlContext)
setSuperField(this, "cliService", sparkSqlCliService)
addService(sparkSqlCliService)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
index 599294dfbb7d7..a4024be67ac9c 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
@@ -18,11 +18,11 @@
package org.apache.spark.sql.hive.thriftserver
private[hive] object ReflectionUtils {
- def setSuperField(obj : Object, fieldName: String, fieldValue: Object) {
+ def setSuperField(obj : Object, fieldName: String, fieldValue: Object): Unit = {
setAncestorField(obj, 1, fieldName, fieldValue)
}
- def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) {
+ def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef): Unit = {
val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next()
val field = ancestor.getDeclaredField(fieldName)
field.setAccessible(true)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 69e85484ccf8e..9ca6c39d016ba 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.execution.command.SetCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.{Utils => SparkUtils}
private[hive] class SparkExecuteStatementOperation(
@@ -77,7 +78,7 @@ private[hive] class SparkExecuteStatementOperation(
HiveThriftServer2.listener.onOperationClosed(statementId)
}
- def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) {
+ def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int): Unit = {
dataTypes(ordinal) match {
case StringType =>
to += from.getString(ordinal)
@@ -103,6 +104,8 @@ private[hive] class SparkExecuteStatementOperation(
to += from.getAs[Timestamp](ordinal)
case BinaryType =>
to += from.getAs[Array[Byte]](ordinal)
+ case CalendarIntervalType =>
+ to += HiveResult.toHiveString((from.getAs[CalendarInterval](ordinal), CalendarIntervalType))
case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] =>
val hiveString = HiveResult.toHiveString((from.get(ordinal), dataTypes(ordinal)))
to += hiveString
@@ -264,6 +267,13 @@ private[hive] class SparkExecuteStatementOperation(
// Actually do need to catch Throwable as some failures don't inherit from Exception and
// HiveServer will silently swallow them.
case e: Throwable =>
+ // When cancel() or close() is called very soon after the query starts, both may call
+ // cleanup() before any Spark jobs have been submitted. The background task may still have
+ // started some Spark jobs before being interrupted, so cancel the job group again here to
+ // make sure those jobs are cancelled once the background thread is interrupted.
+ if (statementId != null) {
+ sqlContext.sparkContext.cancelJobGroup(statementId)
+ }
val currentState = getStatus().getState()
if (currentState.isTerminal) {
// This may happen if the execution was cancelled, and then closed from another thread.
@@ -300,7 +310,7 @@ private[hive] class SparkExecuteStatementOperation(
}
}
- private def cleanup(state: OperationState) {
+ private def cleanup(state: OperationState): Unit = {
setState(state)
if (runInBackground) {
val backgroundHandle = getBackgroundHandle()
@@ -331,7 +341,11 @@ private[hive] class SparkExecuteStatementOperation(
object SparkExecuteStatementOperation {
def getTableSchema(structType: StructType): TableSchema = {
val schema = structType.map { field =>
- val attrTypeString = if (field.dataType == NullType) "void" else field.dataType.catalogString
+ val attrTypeString = field.dataType match {
+ case NullType => "void"
+ case CalendarIntervalType => StringType.catalogString
+ case other => other.catalogString
+ }
new FieldSchema(field.name, attrTypeString, field.getComment.getOrElse(""))
}
new TableSchema(schema.asJava)
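For context beyond the patch itself: the hunk above reports CALENDAR INTERVAL columns to Hive clients under the string type name (and NULL columns as void), since Thrift result schemas have no interval type. A minimal standalone sketch of that mapping, assuming only spark-catalyst on the classpath; thriftTypeName is a hypothetical helper, not a Spark API:

import org.apache.spark.sql.types._

object ThriftTypeNameSketch {
  // Mirrors the match added to getTableSchema above.
  private def thriftTypeName(dt: DataType): String = dt match {
    case NullType => "void"
    case CalendarIntervalType => StringType.catalogString // intervals surface as "string"
    case other => other.catalogString
  }

  def main(args: Array[String]): Unit = {
    println(thriftTypeName(CalendarIntervalType)) // string
    println(thriftTypeName(IntegerType))          // int
  }
}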
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
new file mode 100644
index 0000000000000..7a6a8c59b7216
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.UUID
+
+import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
+import org.apache.hive.service.cli.{HiveSQLException, OperationState}
+import org.apache.hive.service.cli.operation.GetTypeInfoOperation
+import org.apache.hive.service.cli.session.HiveSession
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.{Utils => SparkUtils}
+
+/**
+ * Spark's own GetTypeInfoOperation
+ *
+ * @param sqlContext SQLContext to use
+ * @param parentSession a HiveSession from SessionManager
+ */
+private[hive] class SparkGetTypeInfoOperation(
+ sqlContext: SQLContext,
+ parentSession: HiveSession)
+ extends GetTypeInfoOperation(parentSession) with Logging {
+
+ private var statementId: String = _
+
+ override def close(): Unit = {
+ super.close()
+ HiveThriftServer2.listener.onOperationClosed(statementId)
+ }
+
+ override def runInternal(): Unit = {
+ statementId = UUID.randomUUID().toString
+ val logMsg = "Listing type info"
+ logInfo(s"$logMsg with $statementId")
+ setState(OperationState.RUNNING)
+ // Always use the latest class loader provided by executionHive's state.
+ val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
+ Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
+
+ if (isAuthV2Enabled) {
+ authorizeMetaGets(HiveOperationType.GET_TYPEINFO, null)
+ }
+
+ HiveThriftServer2.listener.onStatementStart(
+ statementId,
+ parentSession.getSessionHandle.getSessionId.toString,
+ logMsg,
+ statementId,
+ parentSession.getUsername)
+
+ try {
+ ThriftserverShimUtils.supportedType().foreach(typeInfo => {
+ val rowData = Array[AnyRef](
+ typeInfo.getName, // TYPE_NAME
+ typeInfo.toJavaSQLType.asInstanceOf[AnyRef], // DATA_TYPE
+ typeInfo.getMaxPrecision.asInstanceOf[AnyRef], // PRECISION
+ typeInfo.getLiteralPrefix, // LITERAL_PREFIX
+ typeInfo.getLiteralSuffix, // LITERAL_SUFFIX
+ typeInfo.getCreateParams, // CREATE_PARAMS
+ typeInfo.getNullable.asInstanceOf[AnyRef], // NULLABLE
+ typeInfo.isCaseSensitive.asInstanceOf[AnyRef], // CASE_SENSITIVE
+ typeInfo.getSearchable.asInstanceOf[AnyRef], // SEARCHABLE
+ typeInfo.isUnsignedAttribute.asInstanceOf[AnyRef], // UNSIGNED_ATTRIBUTE
+ typeInfo.isFixedPrecScale.asInstanceOf[AnyRef], // FIXED_PREC_SCALE
+ typeInfo.isAutoIncrement.asInstanceOf[AnyRef], // AUTO_INCREMENT
+ typeInfo.getLocalizedName, // LOCAL_TYPE_NAME
+ typeInfo.getMinimumScale.asInstanceOf[AnyRef], // MINIMUM_SCALE
+ typeInfo.getMaximumScale.asInstanceOf[AnyRef], // MAXIMUM_SCALE
+ null, // SQL_DATA_TYPE, unused
+ null, // SQL_DATETIME_SUB, unused
+ typeInfo.getNumPrecRadix // NUM_PREC_RADIX
+ )
+ rowSet.addRow(rowData)
+ })
+ setState(OperationState.FINISHED)
+ } catch {
+ case e: HiveSQLException =>
+ setState(OperationState.ERROR)
+ HiveThriftServer2.listener.onStatementError(
+ statementId, e.getMessage, SparkUtils.exceptionString(e))
+ throw e
+ }
+ HiveThriftServer2.listener.onStatementFinish(statementId)
+ }
+}
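This new operation is what ultimately serves DatabaseMetaData.getTypeInfo calls from JDBC clients of the Thrift server. A rough client-side sketch, assuming a HiveThriftServer2 on localhost:10000 and the Hive JDBC driver on the classpath (URL and user are placeholders):

import java.sql.DriverManager

object ListThriftServerTypesSketch {
  def main(args: Array[String]): Unit = {
    // Ensure the Hive JDBC driver is registered.
    Class.forName("org.apache.hive.jdbc.HiveDriver")
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default", "anonymous", "")
    try {
      // Each row corresponds to one entry built by runInternal() above.
      val rs = conn.getMetaData.getTypeInfo
      while (rs.next()) {
        println(s"${rs.getString("TYPE_NAME")} -> ${rs.getInt("DATA_TYPE")}")
      }
    } finally {
      conn.close()
    }
  }
}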
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
index b9614d49eadbd..e3efa2d3ae8c9 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -63,7 +63,7 @@ private[hive] object SparkSQLCLIDriver extends Logging {
* a signal handler will invoke this registered callback if a Ctrl+C signal is detected while
* a command is being processed by the current thread.
*/
- def installSignalHandler() {
+ def installSignalHandler(): Unit = {
HiveInterruptUtils.add(() => {
// Handle remote execution mode
if (SparkSQLEnv.sparkContext != null) {
@@ -77,7 +77,7 @@ private[hive] object SparkSQLCLIDriver extends Logging {
})
}
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val oproc = new OptionsProcessor()
if (!oproc.process_stage1(args)) {
System.exit(1)
@@ -111,6 +111,11 @@ private[hive] object SparkSQLCLIDriver extends Logging {
// Set all properties specified via command line.
val conf: HiveConf = sessionState.getConf
+ // From Hive 2.0.0 onwards, HiveConf.getClassLoader returns the UDFClassLoader (created by
+ // Hive). Because of this change of class loader, Spark can no longer find the jars.
+ // Hive switched class loaders because of HIVE-11878, so we must keep using the old
+ // class loader, since Spark loaded all of its jars into that one.
+ conf.setClassLoader(Thread.currentThread().getContextClassLoader)
sessionState.cmdProperties.entrySet().asScala.foreach { item =>
val key = item.getKey.toString
val value = item.getValue.toString
@@ -133,20 +138,7 @@ private[hive] object SparkSQLCLIDriver extends Logging {
// Clean up after we exit
ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() }
- val remoteMode = isRemoteMode(sessionState)
- // "-h" option has been passed, so connect to Hive thrift server.
- if (!remoteMode) {
- // Hadoop-20 and above - we need to augment classpath using hiveconf
- // components.
- // See also: code in ExecDriver.java
- var loader = conf.getClassLoader
- val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
- if (StringUtils.isNotBlank(auxJars)) {
- loader = ThriftserverShimUtils.addToClassPath(loader, StringUtils.split(auxJars, ","))
- }
- conf.setClassLoader(loader)
- Thread.currentThread().setContextClassLoader(loader)
- } else {
+ if (isRemoteMode(sessionState)) {
// Hive 1.2 + not supported in CLI
throw new RuntimeException("Remote operations not supported")
}
@@ -164,6 +156,22 @@ private[hive] object SparkSQLCLIDriver extends Logging {
val cli = new SparkSQLCLIDriver
cli.setHiveVariables(oproc.getHiveVariables)
+ // In the SparkSQL CLI, users may rely on jars supplied through the hiveconf setting
+ // hive.aux.jars.path. Register those jars with Spark's SessionResourceLoader here so
+ // that the session can load them.
+ val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
+ if (StringUtils.isNotBlank(auxJars)) {
+ val resourceLoader = SparkSQLEnv.sqlContext.sessionState.resourceLoader
+ StringUtils.split(auxJars, ",").foreach(resourceLoader.addJar(_))
+ }
+
+ // The class loader of CliSessionState's conf is the main thread's class loader, which
+ // loaded the jars passed via --jars. AddJarCommand, on the other hand, uses
+ // sharedState.jarClassLoader, which also contains the jar paths passed via --jars on the
+ // main thread. Pointing CliSessionState's conf class loader at sharedState.jarClassLoader
+ // lets us load every jar added either via --jars or via AddJarCommand.
+ sessionState.getConf.setClassLoader(SparkSQLEnv.sqlContext.sharedState.jarClassLoader)
+
// TODO work around for set the log output to console, because the HiveContext
// will set the output into an invalid buffer.
sessionState.in = System.in
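The change above hands hive.aux.jars.path entries to Spark's session resource loader instead of rebuilding Hive's class loader. A loose illustration of the same idea using public APIs only, assuming the local jar paths exist (the paths below are placeholders):

import org.apache.spark.sql.SparkSession

object RegisterAuxJarsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("aux-jars-sketch")
      .getOrCreate()

    // Placeholder; in the real CLI this value comes from hive.aux.jars.path (HIVEAUXJARS).
    val auxJars = "/tmp/serde1.jar,/tmp/serde2.jar"

    // Public-API analogue of resourceLoader.addJar in the diff: register each jar
    // with the session so SerDes and UDFs contained in it can be loaded later.
    auxJars.split(",").filter(_.nonEmpty).foreach(jar => spark.sql(s"ADD JAR $jar"))

    spark.stop()
  }
}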
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
index c32d908ad1bba..1644ecb2453be 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
@@ -43,7 +43,7 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC
extends CLIService(hiveServer)
with ReflectedCompositeService {
- override def init(hiveConf: HiveConf) {
+ override def init(hiveConf: HiveConf): Unit = {
setSuperField(this, "hiveConf", hiveConf)
val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, sqlContext)
@@ -105,7 +105,7 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC
}
private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
- def initCompositeService(hiveConf: HiveConf) {
+ def initCompositeService(hiveConf: HiveConf): Unit = {
// Emulating `CompositeService.init(hiveConf)`
val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList")
serviceList.asScala.foreach(_.init(hiveConf))
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
index 960fdd11db15d..362ac362e9718 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
@@ -94,7 +94,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont
override def getSchema: Schema = tableSchema
- override def destroy() {
+ override def destroy(): Unit = {
super.destroy()
hiveResponse = null
tableSchema = null
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
index 674da18ca1803..2fda9d0a4f60f 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -33,7 +33,7 @@ private[hive] object SparkSQLEnv extends Logging {
var sqlContext: SQLContext = _
var sparkContext: SparkContext = _
- def init() {
+ def init(): Unit = {
if (sqlContext == null) {
val sparkConf = new SparkConf(loadDefaults = true)
// If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of
@@ -60,7 +60,7 @@ private[hive] object SparkSQLEnv extends Logging {
}
/** Cleans up and shuts down the Spark SQL environments. */
- def stop() {
+ def stop(): Unit = {
logDebug("Shutting down Spark SQL Environment")
// Stop the SparkContext
if (SparkSQLEnv.sparkContext != null) {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
index 13055e0ae1394..c4248bfde38cc 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
@@ -38,7 +38,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext:
private lazy val sparkSqlOperationManager = new SparkSQLOperationManager()
- override def init(hiveConf: HiveConf) {
+ override def init(hiveConf: HiveConf): Unit = {
setSuperField(this, "operationManager", sparkSqlOperationManager)
super.init(hiveConf)
}
@@ -63,6 +63,9 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext:
sqlContext.newSession()
}
ctx.setConf(HiveUtils.FAKE_HIVE_VERSION.key, HiveUtils.builtinHiveVersion)
+ val hiveSessionState = session.getSessionState
+ setConfMap(ctx, hiveSessionState.getOverriddenConfigurations)
+ setConfMap(ctx, hiveSessionState.getHiveVariables)
if (sessionConf != null && sessionConf.containsKey("use:database")) {
ctx.sql(s"use ${sessionConf.get("use:database")}")
}
@@ -70,10 +73,18 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext:
sessionHandle
}
- override def closeSession(sessionHandle: SessionHandle) {
+ override def closeSession(sessionHandle: SessionHandle): Unit = {
HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString)
super.closeSession(sessionHandle)
sparkSqlOperationManager.sessionToActivePool.remove(sessionHandle)
sparkSqlOperationManager.sessionToContexts.remove(sessionHandle)
}
+
+ def setConfMap(conf: SQLContext, confMap: java.util.Map[String, String]): Unit = {
+ val iterator = confMap.entrySet().iterator()
+ while (iterator.hasNext) {
+ val kv = iterator.next()
+ conf.setConf(kv.getKey, kv.getValue)
+ }
+ }
}
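The setConfMap helper above copies Hive session overrides and variables into the per-session SQLContext when the session is opened. The same propagation pattern, sketched against the public runtime conf (the map contents are made up for the example):

import scala.collection.JavaConverters._

import org.apache.spark.sql.SparkSession

object PropagateSessionConfsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("conf-propagation-sketch")
      .getOrCreate()

    // Stand-in for hiveSessionState.getOverriddenConfigurations / getHiveVariables.
    val overridden: java.util.Map[String, String] =
      Map("spark.sql.shuffle.partitions" -> "8").asJava

    // Same loop shape as setConfMap, but against the public runtime conf.
    overridden.entrySet().asScala.foreach(kv => spark.conf.set(kv.getKey, kv.getValue))

    assert(spark.conf.get("spark.sql.shuffle.partitions") == "8")
    spark.stop()
  }
}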
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index 35f92547e7815..3396560f43502 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver._
-import org.apache.spark.sql.internal.SQLConf
/**
* Executes queries using Spark SQL, and maintains a list of handles to active queries.
@@ -51,9 +50,6 @@ private[thriftserver] class SparkSQLOperationManager()
require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
s" initialized or had already closed.")
val conf = sqlContext.sessionState.conf
- val hiveSessionState = parentSession.getSessionState
- setConfMap(conf, hiveSessionState.getOverriddenConfigurations)
- setConfMap(conf, hiveSessionState.getHiveVariables)
val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC)
val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay,
runInBackground)(sqlContext, sessionToActivePool)
@@ -145,11 +141,14 @@ private[thriftserver] class SparkSQLOperationManager()
operation
}
- def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = {
- val iterator = confMap.entrySet().iterator()
- while (iterator.hasNext) {
- val kv = iterator.next()
- conf.setConfString(kv.getKey, kv.getValue)
- }
+ override def newGetTypeInfoOperation(
+ parentSession: HiveSession): GetTypeInfoOperation = synchronized {
+ val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
+ require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
+ " initialized or had already closed.")
+ val operation = new SparkGetTypeInfoOperation(sqlContext, parentSession)
+ handleToOperation.put(operation.getHandle, operation)
+ logDebug(s"Created GetTypeInfoOperation with session=$parentSession.")
+ operation
}
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala
index 261e8fc912eb9..4056be4769d21 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala
@@ -26,6 +26,7 @@ import org.apache.commons.text.StringEscapeUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState, SessionInfo}
+import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._
import org.apache.spark.ui._
import org.apache.spark.ui.UIUtils._
@@ -72,6 +73,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
val table = if (numStatement > 0) {
val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Close Time",
"Execution Time", "Duration", "Statement", "State", "Detail")
+ val tooltips = Seq(None, None, None, None, Some(THRIFT_SERVER_FINISH_TIME),
+ Some(THRIFT_SERVER_CLOSE_TIME), Some(THRIFT_SERVER_EXECUTION),
+ Some(THRIFT_SERVER_DURATION), None, None, None)
+ assert(headerRow.length == tooltips.length)
val dataRows = listener.getExecutionList.sortBy(_.startTimestamp).reverse
def generateDataRow(info: ExecutionInfo): Seq[Node] = {
@@ -91,8 +96,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
{formatDate(info.startTimestamp)}
{if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)}
{if (info.closeTimestamp > 0) formatDate(info.closeTimestamp)}
- {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))}
- {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))}
+
+ {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))}
+
+ {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))}
{info.statement}
{info.state}
{errorMessageCell(detail)}
@@ -100,7 +107,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
}
Some(UIUtils.listingTable(headerRow, generateDataRow,
- dataRows, false, None, Seq(null), false))
+ dataRows, false, None, Seq(null), false, tooltipHeaders = tooltips))
} else {
None
}
@@ -157,7 +164,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage(""
{session.sessionId}
{formatDate(session.startTimestamp)}
{if (session.finishTimestamp > 0) formatDate(session.finishTimestamp)}
- {formatDurationOption(Some(session.totalTime))}
+
+ {formatDurationOption(Some(session.totalTime))}
{session.totalExecution.toString}
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
index 81df1304085e8..0aa0a2b8335d8 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
@@ -26,6 +26,7 @@ import org.apache.commons.text.StringEscapeUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState}
+import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._
import org.apache.spark.ui._
import org.apache.spark.ui.UIUtils._
@@ -81,6 +82,10 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
val table = if (numStatement > 0) {
val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Close Time",
"Execution Time", "Duration", "Statement", "State", "Detail")
+ val tooltips = Seq(None, None, None, None, Some(THRIFT_SERVER_FINISH_TIME),
+ Some(THRIFT_SERVER_CLOSE_TIME), Some(THRIFT_SERVER_EXECUTION),
+ Some(THRIFT_SERVER_DURATION), None, None, None)
+ assert(headerRow.length == tooltips.length)
val dataRows = executionList.sortBy(_.startTimestamp).reverse
def generateDataRow(info: ExecutionInfo): Seq[Node] = {
@@ -98,10 +103,14 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
{info.groupId}
{formatDate(info.startTimestamp)}
- {formatDate(info.finishTimestamp)}
- {formatDate(info.closeTimestamp)}
- {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))}
- {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))}
+ {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)}
+ {if (info.closeTimestamp > 0) formatDate(info.closeTimestamp)}
+
+ {formatDurationOption(Some(info.totalTime(info.finishTimestamp)))}
+
+
+ {formatDurationOption(Some(info.totalTime(info.closeTimestamp)))}
+
{info.statement}
{info.state}
{errorMessageCell(detail)}
@@ -109,7 +118,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
}
Some(UIUtils.listingTable(headerRow, generateDataRow,
- dataRows, false, None, Seq(null), false))
+ dataRows, false, None, Seq(null), false, tooltipHeaders = tooltips))
} else {
None
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala
index db2066009b351..8efb2c3311cfe 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala
@@ -39,7 +39,7 @@ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext)
attachPage(new ThriftServerSessionPage(this))
parent.attachTab(this)
- def detach() {
+ def detach(): Unit = {
getSparkUI(sparkContext).detachTab(this)
}
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ToolTips.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ToolTips.scala
new file mode 100644
index 0000000000000..1990b8f2d3285
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ToolTips.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver.ui
+
+private[ui] object ToolTips {
+ val THRIFT_SERVER_FINISH_TIME =
+ "Execution finish time, before fetching the results"
+
+ val THRIFT_SERVER_CLOSE_TIME =
+ "Operation close time after fetching the results"
+
+ val THRIFT_SERVER_EXECUTION =
+ "Difference between start time and finish time"
+
+ val THRIFT_SERVER_DURATION =
+ "Difference between start time and close time"
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index 6e042ac41d9da..f3063675a79f7 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -27,12 +27,11 @@ import scala.concurrent.Promise
import scala.concurrent.duration._
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.hive.test.HiveTestUtils
+import org.apache.spark.sql.hive.test.HiveTestJars
import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
import org.apache.spark.util.{ThreadUtils, Utils}
@@ -202,7 +201,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
}
test("Commands using SerDe provided in --jars") {
- val jarFile = HiveTestUtils.getHiveHcatalogCoreJar.getCanonicalPath
+ val jarFile = HiveTestJars.getHiveHcatalogCoreJar().getCanonicalPath
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
@@ -218,8 +217,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
-> "",
"INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;"
-> "",
- "SELECT count(key) FROM t1;"
- -> "5",
+ "SELECT collect_list(array(val)) FROM t1;"
+ -> """[["val_238"],["val_86"],["val_311"],["val_27"],["val_165"]]""",
"DROP TABLE t1;"
-> "",
"DROP TABLE sourceTable;"
@@ -227,6 +226,32 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
)
}
+ test("SPARK-29022: Commands using SerDe provided in --hive.aux.jars.path") {
+ val dataFilePath =
+ Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
+ val hiveContribJar = HiveTestJars.getHiveHcatalogCoreJar().getCanonicalPath
+ runCliWithin(
+ 3.minute,
+ Seq("--conf", s"spark.hadoop.${ConfVars.HIVEAUXJARS}=$hiveContribJar"))(
+ """CREATE TABLE addJarWithHiveAux(key string, val string)
+ |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe';
+ """.stripMargin
+ -> "",
+ "CREATE TABLE sourceTableForWithHiveAux (key INT, val STRING);"
+ -> "",
+ s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTableForWithHiveAux;"
+ -> "",
+ "INSERT INTO TABLE addJarWithHiveAux SELECT key, val FROM sourceTableForWithHiveAux;"
+ -> "",
+ "SELECT collect_list(array(val)) FROM addJarWithHiveAux;"
+ -> """[["val_238"],["val_86"],["val_311"],["val_27"],["val_165"]]""",
+ "DROP TABLE addJarWithHiveAux;"
+ -> "",
+ "DROP TABLE sourceTableForWithHiveAux;"
+ -> ""
+ )
+ }
+
test("SPARK-11188 Analysis error reporting") {
runCliWithin(timeout = 2.minute,
errorResponses = Seq("AnalysisException"))(
@@ -297,12 +322,66 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
}
test("Support hive.aux.jars.path") {
- val hiveContribJar = HiveTestUtils.getHiveContribJar.getCanonicalPath
+ val hiveContribJar = HiveTestJars.getHiveContribJar().getCanonicalPath
runCliWithin(
1.minute,
Seq("--conf", s"spark.hadoop.${ConfVars.HIVEAUXJARS}=$hiveContribJar"))(
- s"CREATE TEMPORARY FUNCTION example_max AS '${classOf[UDAFExampleMax].getName}';" -> "",
- "SELECT example_max(1);" -> "1"
+ "CREATE TEMPORARY FUNCTION example_format AS " +
+ "'org.apache.hadoop.hive.contrib.udf.example.UDFExampleFormat';" -> "",
+ "SELECT example_format('%o', 93);" -> "135"
+ )
+ }
+
+ test("SPARK-28840 test --jars command") {
+ val jarFile = new File("../../sql/hive/src/test/resources/SPARK-21101-1.0.jar").getCanonicalPath
+ runCliWithin(
+ 1.minute,
+ Seq("--jars", s"$jarFile"))(
+ "CREATE TEMPORARY FUNCTION testjar AS" +
+ " 'org.apache.spark.sql.hive.execution.UDTFStack';" -> "",
+ "SELECT testjar(1,'TEST-SPARK-TEST-jar', 28840);" -> "TEST-SPARK-TEST-jar\t28840"
+ )
+ }
+
+ test("SPARK-28840 test --jars and hive.aux.jars.path command") {
+ val jarFile = new File("../../sql/hive/src/test/resources/SPARK-21101-1.0.jar").getCanonicalPath
+ val hiveContribJar = HiveTestJars.getHiveContribJar().getCanonicalPath
+ runCliWithin(
+ 1.minute,
+ Seq("--jars", s"$jarFile", "--conf",
+ s"spark.hadoop.${ConfVars.HIVEAUXJARS}=$hiveContribJar"))(
+ "CREATE TEMPORARY FUNCTION testjar AS" +
+ " 'org.apache.spark.sql.hive.execution.UDTFStack';" -> "",
+ "SELECT testjar(1,'TEST-SPARK-TEST-jar', 28840);" -> "TEST-SPARK-TEST-jar\t28840",
+ "CREATE TEMPORARY FUNCTION example_max AS " +
+ "'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax';" -> "",
+ "SELECT concat_ws(',', 'First', example_max(1234321), 'Third');" -> "First,1234321,Third"
+ )
+ }
+
+ test("SPARK-29022 Commands using SerDe provided in ADD JAR sql") {
+ val dataFilePath =
+ Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
+ val hiveContribJar = HiveTestJars.getHiveHcatalogCoreJar().getCanonicalPath
+ runCliWithin(
+ 3.minute)(
+ s"ADD JAR ${hiveContribJar};" -> "",
+ """CREATE TABLE addJarWithSQL(key string, val string)
+ |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe';
+ """.stripMargin
+ -> "",
+ "CREATE TABLE sourceTableForWithSQL(key INT, val STRING);"
+ -> "",
+ s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTableForWithSQL;"
+ -> "",
+ "INSERT INTO TABLE addJarWithSQL SELECT key, val FROM sourceTableForWithSQL;"
+ -> "",
+ "SELECT collect_list(array(val)) FROM addJarWithSQL;"
+ -> """[["val_238"],["val_86"],["val_311"],["val_27"],["val_165"]]""",
+ "DROP TABLE addJarWithSQL;"
+ -> "",
+ "DROP TABLE sourceTableForWithSQL;"
+ -> ""
)
}
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index b7185db2f2ae7..8a5526ea780ef 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -43,7 +43,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.hive.test.HiveTestUtils
+import org.apache.spark.sql.hive.test.HiveTestJars
import org.apache.spark.sql.internal.StaticSQLConf.HIVE_THRIFT_SERVER_SINGLESESSION
import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
import org.apache.spark.util.{ThreadUtils, Utils}
@@ -144,10 +144,17 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
def executeTest(hiveList: String): Unit = {
hiveList.split(";").foreach{ m =>
val kv = m.split("=")
- // select "${a}"; ---> avalue
- val resultSet = statement.executeQuery("select \"${" + kv(0) + "}\"")
+ val k = kv(0)
+ val v = kv(1)
+ val modValue = s"${v}_MOD_VALUE"
+ // select '${a}'; ---> avalue
+ val resultSet = statement.executeQuery(s"select '$${$k}'")
resultSet.next()
- assert(resultSet.getString(1) === kv(1))
+ assert(resultSet.getString(1) === v)
+ statement.executeQuery(s"set $k=$modValue")
+ val modResultSet = statement.executeQuery(s"select '$${$k}'")
+ modResultSet.next()
+ assert(modResultSet.getString(1) === s"$modValue")
}
}
}
@@ -485,7 +492,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
withMultipleConnectionJdbcStatement("smallKV", "addJar")(
{
statement =>
- val jarFile = HiveTestUtils.getHiveHcatalogCoreJar.getCanonicalPath
+ val jarFile = HiveTestJars.getHiveHcatalogCoreJar().getCanonicalPath
statement.executeQuery(s"ADD JAR $jarFile")
},
@@ -662,6 +669,21 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
assert(rs.getBigDecimal(1) === new java.math.BigDecimal("1.000000000000000000"))
}
}
+
+ test("Support interval type") {
+ withJdbcStatement() { statement =>
+ val rs = statement.executeQuery("SELECT interval 3 months 1 hours")
+ assert(rs.next())
+ assert(rs.getString(1) === "interval 3 months 1 hours")
+ }
+ // Invalid interval value
+ withJdbcStatement() { statement =>
+ val e = intercept[SQLException] {
+ statement.executeQuery("SELECT interval 3 months 1 hou")
+ }
+ assert(e.getMessage.contains("org.apache.spark.sql.catalyst.parser.ParseException"))
+ }
+ }
}
class SingleSessionSuite extends HiveThriftJdbcTest {
@@ -820,7 +842,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
s"jdbc:hive2://localhost:$serverPort/?${hiveConfList}#${hiveVarList}"
}
- def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*) {
+ def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*): Unit = {
val user = System.getProperty("user.name")
val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
val statements = connections.map(_.createStatement())
@@ -841,7 +863,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
}
}
- def withDatabase(dbNames: String*)(fs: (Statement => Unit)*) {
+ def withDatabase(dbNames: String*)(fs: (Statement => Unit)*): Unit = {
val user = System.getProperty("user.name")
val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
val statements = connections.map(_.createStatement())
@@ -857,7 +879,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
}
}
- def withJdbcStatement(tableNames: String*)(f: Statement => Unit) {
+ def withJdbcStatement(tableNames: String*)(f: Statement => Unit): Unit = {
withMultipleConnectionJdbcStatement(tableNames: _*)(f)
}
}
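The new interval test above checks the end-to-end behaviour enabled by the CalendarIntervalType changes: interval results reach JDBC clients as strings. A standalone sketch of the same check, assuming a running HiveThriftServer2 on localhost:10000 and the Hive JDBC driver on the classpath:

import java.sql.DriverManager

object IntervalOverJdbcSketch {
  def main(args: Array[String]): Unit = {
    // Ensure the Hive JDBC driver is registered.
    Class.forName("org.apache.hive.jdbc.HiveDriver")
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default", "anonymous", "")
    try {
      val rs = conn.createStatement().executeQuery("SELECT interval 3 months 1 hours")
      if (rs.next()) {
        // Interval values arrive as plain strings, e.g. "interval 3 months 1 hours".
        println(rs.getString(1))
      }
    } finally {
      conn.close()
    }
  }
}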
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
index 21870ffd463ec..f7ee3e0a46cd1 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
@@ -231,4 +231,20 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest {
assert(!rs.next())
}
}
+
+ test("GetTypeInfo Thrift API") {
+ def checkResult(rs: ResultSet, typeNames: Seq[String]): Unit = {
+ for (i <- typeNames.indices) {
+ assert(rs.next())
+ assert(rs.getString("TYPE_NAME") === typeNames(i))
+ }
+ // Make sure there are no more elements
+ assert(!rs.next())
+ }
+
+ withJdbcStatement() { statement =>
+ val metaData = statement.getConnection.getMetaData
+ checkResult(metaData.getTypeInfo, ThriftserverShimUtils.supportedType().map(_.getName))
+ }
+ }
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala
index f198372a4c998..10ec1ee168303 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkThriftServerProtocolVersionsSuite.scala
@@ -261,10 +261,10 @@ class SparkThriftServerProtocolVersionsSuite extends HiveThriftJdbcTest {
}
}
- // We do not fully support interval type
- ignore(s"$version get interval type") {
+ test(s"$version get interval type") {
testExecuteStatementWithProtocolVersion(version, "SELECT interval '1' year '2' day") { rs =>
assert(rs.next())
+ assert(rs.getString(1) === "interval 1 years 2 days")
}
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
index 1f7b3feae47b5..613c1655727bb 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
@@ -18,19 +18,21 @@
package org.apache.spark.sql.hive.thriftserver
import java.io.File
-import java.sql.{DriverManager, SQLException, Statement, Timestamp}
-import java.util.Locale
+import java.sql.{DriverManager, Statement, Timestamp}
+import java.util.{Locale, MissingFormatArgumentException}
import scala.util.{Random, Try}
import scala.util.control.NonFatal
+import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hive.service.cli.HiveSQLException
-import org.scalatest.Ignore
+import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, SQLQueryTestSuite}
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.util.fileToString
import org.apache.spark.sql.execution.HiveResult
+import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -43,12 +45,12 @@ import org.apache.spark.sql.types._
* 2. Support DESC command.
* 3. Support SHOW command.
*/
-@Ignore
class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
private var hiveServer2: HiveThriftServer2 = _
- override def beforeEach(): Unit = {
+ override def beforeAll(): Unit = {
+ super.beforeAll()
// Chooses a random port between 10000 and 19999
var listeningPort = 10000 + Random.nextInt(10000)
@@ -65,36 +67,40 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
logInfo("HiveThriftServer2 started successfully")
}
- override def afterEach(): Unit = {
- hiveServer2.stop()
+ override def afterAll(): Unit = {
+ try {
+ hiveServer2.stop()
+ } finally {
+ super.afterAll()
+ }
}
+ override def sparkConf: SparkConf = super.sparkConf
+ // The Hive Thrift server should not execute SQL queries asynchronously here,
+ // because we may set session configurations.
+ .set(HiveUtils.HIVE_THRIFT_SERVER_ASYNC, false)
+
override val isTestWithConfigSets = false
/** List of test cases to ignore, in lower cases. */
override def blackList: Set[String] = Set(
"blacklist.sql", // Do NOT remove this one. It is here to test the blacklist functionality.
// Missing UDF
- "pgSQL/boolean.sql",
- "pgSQL/case.sql",
+ "postgreSQL/boolean.sql",
+ "postgreSQL/case.sql",
// SPARK-28624
"date.sql",
- // SPARK-28619
- "pgSQL/aggregates_part1.sql",
- "group-by.sql",
// SPARK-28620
- "pgSQL/float4.sql",
+ "postgreSQL/float4.sql",
// SPARK-28636
"decimalArithmeticOperations.sql",
"literals.sql",
"subquery/scalar-subquery/scalar-subquery-predicate.sql",
"subquery/in-subquery/in-limit.sql",
+ "subquery/in-subquery/in-group-by.sql",
"subquery/in-subquery/simple-in.sql",
"subquery/in-subquery/in-order-by.sql",
- "subquery/in-subquery/in-set-operations.sql",
- // SPARK-28637
- "cast.sql",
- "ansi/interval.sql"
+ "subquery/in-subquery/in-set-operations.sql"
)
override def runQueries(
@@ -110,8 +116,8 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
case _: PgSQLTest =>
// PostgreSQL enabled cartesian product by default.
statement.execute(s"SET ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
- statement.execute(s"SET ${SQLConf.ANSI_SQL_PARSER.key} = true")
- statement.execute(s"SET ${SQLConf.PREFER_INTEGRAL_DIVISION.key} = true")
+ statement.execute(s"SET ${SQLConf.ANSI_ENABLED.key} = true")
+ statement.execute(s"SET ${SQLConf.DIALECT.key} = ${SQLConf.Dialect.POSTGRESQL.toString}")
case _ =>
}
@@ -166,19 +172,42 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
|| d.sql.toUpperCase(Locale.ROOT).startsWith("DESC\n")
|| d.sql.toUpperCase(Locale.ROOT).startsWith("DESCRIBE ")
|| d.sql.toUpperCase(Locale.ROOT).startsWith("DESCRIBE\n") =>
+
// Skip show command, see HiveResult.hiveResultString
case s if s.sql.toUpperCase(Locale.ROOT).startsWith("SHOW ")
|| s.sql.toUpperCase(Locale.ROOT).startsWith("SHOW\n") =>
- // AnalysisException should exactly match.
- // SQLException should not exactly match. We only assert the result contains Exception.
- case _ if output.output.startsWith(classOf[SQLException].getName) =>
+
+ case _ if output.output.startsWith(classOf[NoSuchTableException].getPackage.getName) =>
+ assert(expected.output.startsWith(classOf[NoSuchTableException].getPackage.getName),
+ s"Exception did not match for query #$i\n${expected.sql}, " +
+ s"expected: ${expected.output}, but got: ${output.output}")
+
+ case _ if output.output.startsWith(classOf[SparkException].getName) &&
+ output.output.contains("overflow") =>
+ assert(expected.output.contains(classOf[ArithmeticException].getName) &&
+ expected.output.contains("overflow"),
+ s"Exception did not match for query #$i\n${expected.sql}, " +
+ s"expected: ${expected.output}, but got: ${output.output}")
+
+ case _ if output.output.startsWith(classOf[RuntimeException].getName) =>
assert(expected.output.contains("Exception"),
s"Exception did not match for query #$i\n${expected.sql}, " +
s"expected: ${expected.output}, but got: ${output.output}")
- // HiveSQLException is usually a feature that our ThriftServer cannot support.
- // Please add SQL to blackList.
- case _ if output.output.startsWith(classOf[HiveSQLException].getName) =>
- assert(false, s"${output.output} for query #$i\n${expected.sql}")
+
+ case _ if output.output.startsWith(classOf[ArithmeticException].getName) &&
+ output.output.contains("causes overflow") =>
+ assert(expected.output.contains(classOf[ArithmeticException].getName) &&
+ expected.output.contains("causes overflow"),
+ s"Exception did not match for query #$i\n${expected.sql}, " +
+ s"expected: ${expected.output}, but got: ${output.output}")
+
+ case _ if output.output.startsWith(classOf[MissingFormatArgumentException].getName) &&
+ output.output.contains("Format specifier") =>
+ assert(expected.output.contains(classOf[MissingFormatArgumentException].getName) &&
+ expected.output.contains("Format specifier"),
+ s"Exception did not match for query #$i\n${expected.sql}, " +
+ s"expected: ${expected.output}, but got: ${output.output}")
+
case _ =>
assertResult(expected.output, s"Result did not match for query #$i\n${expected.sql}") {
output.output
@@ -209,7 +238,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udf")) {
Seq.empty
- } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}pgSQL")) {
+ } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}postgreSQL")) {
PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil
} else {
RegularTestCase(testCaseName, absPath, resultFile) :: Nil
@@ -248,8 +277,9 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")).sorted
case NonFatal(e) =>
+ val rootCause = ExceptionUtils.getRootCause(e)
// If there is an exception, put the exception class followed by the message.
- Seq(e.getClass.getName, e.getMessage)
+ Seq(rootCause.getClass.getName, rootCause.getMessage)
}
}
@@ -260,7 +290,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
hiveServer2 = HiveThriftServer2.startWithContext(sqlContext)
}
- private def withJdbcStatement(fs: (Statement => Unit)*) {
+ private def withJdbcStatement(fs: (Statement => Unit)*): Unit = {
val user = System.getProperty("user.name")
val serverPort = hiveServer2.getHiveConf.get(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname)
@@ -337,7 +367,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
upperCase.startsWith("SELECT ") || upperCase.startsWith("SELECT\n") ||
upperCase.startsWith("WITH ") || upperCase.startsWith("WITH\n") ||
upperCase.startsWith("VALUES ") || upperCase.startsWith("VALUES\n") ||
- // pgSQL/union.sql
+ // postgreSQL/union.sql
upperCase.startsWith("(")
}
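One detail worth spelling out from the hunk above: comparing against ExceptionUtils.getRootCause keeps the expected output stable even when the server wraps the real error in another exception before it reaches the client. A tiny illustration, with made-up exception messages:

import org.apache.commons.lang3.exception.ExceptionUtils

object RootCauseSketch {
  def main(args: Array[String]): Unit = {
    // Server-side failures often arrive wrapped; the root cause carries the useful class/message.
    val wrapped = new RuntimeException("wrapper", new ArithmeticException("overflow"))
    val root = ExceptionUtils.getRootCause(wrapped)
    println(root.getClass.getName) // java.lang.ArithmeticException
    println(root.getMessage)       // overflow
  }
}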
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
index 47cf4f104d204..7f731f3d05e51 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
@@ -24,8 +24,8 @@ import org.openqa.selenium.WebDriver
import org.openqa.selenium.htmlunit.HtmlUnitDriver
import org.scalatest.{BeforeAndAfterAll, Matchers}
import org.scalatest.concurrent.Eventually._
-import org.scalatest.selenium.WebBrowser
import org.scalatest.time.SpanSugar._
+import org.scalatestplus.selenium.WebBrowser
import org.apache.spark.ui.SparkUICssErrorHandler
diff --git a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java
index 0f72071d7e7d1..3e81f8afbd85f 100644
--- a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java
+++ b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java
@@ -73,7 +73,7 @@ public class GetTypeInfoOperation extends MetadataOperation {
.addPrimitiveColumn("NUM_PREC_RADIX", Type.INT_TYPE,
"Usually 2 or 10");
- private final RowSet rowSet;
+ protected final RowSet rowSet;
protected GetTypeInfoOperation(HiveSession parentSession) {
super(parentSession, OperationType.GET_TYPE_INFO);
diff --git a/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala b/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala
index 87c0f8f6a571a..fbfc698ecb4bf 100644
--- a/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala
+++ b/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala
@@ -18,9 +18,9 @@
package org.apache.spark.sql.hive.thriftserver
import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.hive.ql.exec.Utilities
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hive.service.cli.{RowSet, RowSetFactory, TableSchema, Type}
+import org.apache.hive.service.cli.Type._
import org.apache.hive.service.cli.thrift.TProtocolVersion._
/**
@@ -51,10 +51,12 @@ private[thriftserver] object ThriftserverShimUtils {
private[thriftserver] def toJavaSQLType(s: String): Int = Type.getType(s).toJavaSQLType
- private[thriftserver] def addToClassPath(
- loader: ClassLoader,
- auxJars: Array[String]): ClassLoader = {
- Utilities.addToClassPath(loader, auxJars)
+ private[thriftserver] def supportedType(): Seq[Type] = {
+ Seq(NULL_TYPE, BOOLEAN_TYPE, STRING_TYPE, BINARY_TYPE,
+ TINYINT_TYPE, SMALLINT_TYPE, INT_TYPE, BIGINT_TYPE,
+ FLOAT_TYPE, DOUBLE_TYPE, DECIMAL_TYPE,
+ DATE_TYPE, TIMESTAMP_TYPE,
+ ARRAY_TYPE, MAP_TYPE, STRUCT_TYPE)
}
private[thriftserver] val testedProtocolVersions = Seq(
diff --git a/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java b/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java
index 9612eb145638c..0f57a72e2a1ce 100644
--- a/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java
+++ b/sql/hive-thriftserver/v2.3.5/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java
@@ -73,7 +73,7 @@ public class GetTypeInfoOperation extends MetadataOperation {
.addPrimitiveColumn("NUM_PREC_RADIX", Type.INT_TYPE,
"Usually 2 or 10");
- private final RowSet rowSet;
+ protected final RowSet rowSet;
protected GetTypeInfoOperation(HiveSession parentSession) {
super(parentSession, OperationType.GET_TYPE_INFO);
diff --git a/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala b/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala
index 124c9937c0fca..850382fe2bfd7 100644
--- a/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala
+++ b/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala
@@ -17,13 +17,9 @@
package org.apache.spark.sql.hive.thriftserver
-import java.security.AccessController
-
-import scala.collection.JavaConverters._
-
-import org.apache.hadoop.hive.ql.exec.AddToClassPathAction
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde2.thrift.Type
+import org.apache.hadoop.hive.serde2.thrift.Type._
import org.apache.hive.service.cli.{RowSet, RowSetFactory, TableSchema}
import org.apache.hive.service.rpc.thrift.TProtocolVersion._
import org.slf4j.LoggerFactory
@@ -56,11 +52,12 @@ private[thriftserver] object ThriftserverShimUtils {
private[thriftserver] def toJavaSQLType(s: String): Int = Type.getType(s).toJavaSQLType
- private[thriftserver] def addToClassPath(
- loader: ClassLoader,
- auxJars: Array[String]): ClassLoader = {
- val addAction = new AddToClassPathAction(loader, auxJars.toList.asJava)
- AccessController.doPrivileged(addAction)
+ private[thriftserver] def supportedType(): Seq[Type] = {
+ Seq(NULL_TYPE, BOOLEAN_TYPE, STRING_TYPE, BINARY_TYPE,
+ TINYINT_TYPE, SMALLINT_TYPE, INT_TYPE, BIGINT_TYPE,
+ FLOAT_TYPE, DOUBLE_TYPE, DECIMAL_TYPE,
+ DATE_TYPE, TIMESTAMP_TYPE,
+ ARRAY_TYPE, MAP_TYPE, STRUCT_TYPE)
}
private[thriftserver] val testedProtocolVersions = Seq(
diff --git a/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt b/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt
index f3044da972497..0c394a340333a 100644
--- a/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt
+++ b/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt
@@ -2,44 +2,44 @@
Hive UDAF vs Spark AF
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-hive udaf w/o group by 6370 / 6400 0.0 97193.6 1.0X
-spark af w/o group by 54 / 63 1.2 820.8 118.4X
-hive udaf w/ group by 4492 / 4507 0.0 68539.5 1.4X
-spark af w/ group by w/o fallback 58 / 64 1.1 881.7 110.2X
-spark af w/ group by w/ fallback 136 / 142 0.5 2075.0 46.8X
+hive udaf vs spark af: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+hive udaf w/o group by 6741 6759 22 0.0 102864.5 1.0X
+spark af w/o group by 56 66 9 1.2 851.6 120.8X
+hive udaf w/ group by 4610 4642 25 0.0 70350.3 1.5X
+spark af w/ group by w/o fallback 60 67 8 1.1 916.7 112.2X
+spark af w/ group by w/ fallback 135 144 9 0.5 2065.6 49.8X
================================================================================================
ObjectHashAggregateExec vs SortAggregateExec - typed_count
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-object agg v.s. sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-sort agg w/ group by 41500 / 41630 2.5 395.8 1.0X
-object agg w/ group by w/o fallback 10075 / 10122 10.4 96.1 4.1X
-object agg w/ group by w/ fallback 28131 / 28205 3.7 268.3 1.5X
-sort agg w/o group by 6182 / 6221 17.0 59.0 6.7X
-object agg w/o group by w/o fallback 5435 / 5468 19.3 51.8 7.6X
+object agg v.s. sort agg: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+sort agg w/ group by 41568 41894 461 2.5 396.4 1.0X
+object agg w/ group by w/o fallback 10314 10494 149 10.2 98.4 4.0X
+object agg w/ group by w/ fallback 26720 26951 326 3.9 254.8 1.6X
+sort agg w/o group by 6638 6681 38 15.8 63.3 6.3X
+object agg w/o group by w/o fallback 5665 5706 30 18.5 54.0 7.3X
================================================================================================
ObjectHashAggregateExec vs SortAggregateExec - percentile_approx
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-object agg v.s. sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-sort agg w/ group by 970 / 1025 2.2 462.5 1.0X
-object agg w/ group by w/o fallback 772 / 798 2.7 368.1 1.3X
-object agg w/ group by w/ fallback 1013 / 1044 2.1 483.1 1.0X
-sort agg w/o group by 751 / 781 2.8 358.0 1.3X
-object agg w/o group by w/o fallback 772 / 814 2.7 368.0 1.3X
+object agg v.s. sort agg: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+sort agg w/ group by 794 862 33 2.6 378.8 1.0X
+object agg w/ group by w/o fallback 605 622 10 3.5 288.5 1.3X
+object agg w/ group by w/ fallback 840 860 15 2.5 400.5 0.9X
+sort agg w/o group by 555 570 12 3.8 264.6 1.4X
+object agg w/o group by w/o fallback 544 562 12 3.9 259.6 1.5X
diff --git a/sql/hive/benchmarks/OrcReadBenchmark-results.txt b/sql/hive/benchmarks/OrcReadBenchmark-results.txt
index caa78b9a8f102..c47cf27bf617a 100644
--- a/sql/hive/benchmarks/OrcReadBenchmark-results.txt
+++ b/sql/hive/benchmarks/OrcReadBenchmark-results.txt
@@ -2,155 +2,155 @@
SQL Single Numeric Column Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 1725 / 1759 9.1 109.7 1.0X
-Native ORC Vectorized 272 / 316 57.8 17.3 6.3X
-Hive built-in ORC 1970 / 1987 8.0 125.3 0.9X
+SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 1843 1958 162 8.5 117.2 1.0X
+Native ORC Vectorized 321 355 31 48.9 20.4 5.7X
+Hive built-in ORC 2143 2175 44 7.3 136.3 0.9X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 1633 / 1672 9.6 103.8 1.0X
-Native ORC Vectorized 238 / 255 66.0 15.1 6.9X
-Hive built-in ORC 2293 / 2305 6.9 145.8 0.7X
+SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 1987 2020 47 7.9 126.3 1.0X
+Native ORC Vectorized 276 299 25 57.0 17.6 7.2X
+Hive built-in ORC 2350 2357 10 6.7 149.4 0.8X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 1677 / 1699 9.4 106.6 1.0X
-Native ORC Vectorized 325 / 342 48.3 20.7 5.2X
-Hive built-in ORC 2561 / 2569 6.1 162.8 0.7X
+SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 2092 2115 32 7.5 133.0 1.0X
+Native ORC Vectorized 360 373 18 43.6 22.9 5.8X
+Hive built-in ORC 2550 2557 9 6.2 162.2 0.8X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 1791 / 1795 8.8 113.9 1.0X
-Native ORC Vectorized 400 / 408 39.3 25.4 4.5X
-Hive built-in ORC 2713 / 2720 5.8 172.5 0.7X
+SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 2173 2188 21 7.2 138.2 1.0X
+Native ORC Vectorized 435 448 14 36.2 27.7 5.0X
+Hive built-in ORC 2683 2690 10 5.9 170.6 0.8X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 1791 / 1805 8.8 113.8 1.0X
-Native ORC Vectorized 433 / 438 36.3 27.5 4.1X
-Hive built-in ORC 2690 / 2803 5.8 171.0 0.7X
+SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 2233 2323 127 7.0 142.0 1.0X
+Native ORC Vectorized 475 483 13 33.1 30.2 4.7X
+Hive built-in ORC 2605 2610 6 6.0 165.7 0.9X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 1911 / 1930 8.2 121.5 1.0X
-Native ORC Vectorized 543 / 552 29.0 34.5 3.5X
-Hive built-in ORC 2967 / 3065 5.3 188.6 0.6X
+SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 2367 2384 24 6.6 150.5 1.0X
+Native ORC Vectorized 600 641 69 26.2 38.1 3.9X
+Hive built-in ORC 2860 2877 24 5.5 181.9 0.8X
================================================================================================
Int and String Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 4160 / 4188 2.5 396.7 1.0X
-Native ORC Vectorized 2405 / 2406 4.4 229.4 1.7X
-Hive built-in ORC 5514 / 5562 1.9 525.9 0.8X
+Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 4253 4330 108 2.5 405.6 1.0X
+Native ORC Vectorized 2295 2301 8 4.6 218.9 1.9X
+Hive built-in ORC 5364 5465 144 2.0 511.5 0.8X
================================================================================================
Partitioned Table Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Data column - Native ORC MR 1863 / 1867 8.4 118.4 1.0X
-Data column - Native ORC Vectorized 411 / 418 38.2 26.2 4.5X
-Data column - Hive built-in ORC 3297 / 3308 4.8 209.6 0.6X
-Partition column - Native ORC MR 1505 / 1506 10.4 95.7 1.2X
-Partition column - Native ORC Vectorized 80 / 93 195.6 5.1 23.2X
-Partition column - Hive built-in ORC 1960 / 1979 8.0 124.6 1.0X
-Both columns - Native ORC MR 2076 / 2090 7.6 132.0 0.9X
-Both columns - Native ORC Vectorized 450 / 463 34.9 28.6 4.1X
-Both columns - Hive built-in ORC 3528 / 3548 4.5 224.3 0.5X
+Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Data column - Native ORC MR 2443 2448 6 6.4 155.3 1.0X
+Data column - Native ORC Vectorized 446 473 44 35.3 28.3 5.5X
+Data column - Hive built-in ORC 2868 2877 12 5.5 182.4 0.9X
+Partition column - Native ORC MR 1623 1656 47 9.7 103.2 1.5X
+Partition column - Native ORC Vectorized 112 121 14 140.8 7.1 21.9X
+Partition column - Hive built-in ORC 1846 1850 5 8.5 117.4 1.3X
+Both columns - Native ORC MR 2610 2635 36 6.0 165.9 0.9X
+Both columns - Native ORC Vectorized 492 508 19 32.0 31.3 5.0X
+Both columns - Hive built-in ORC 2969 2973 4 5.3 188.8 0.8X
================================================================================================
Repeated String Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 1727 / 1733 6.1 164.7 1.0X
-Native ORC Vectorized 375 / 379 28.0 35.7 4.6X
-Hive built-in ORC 2665 / 2666 3.9 254.2 0.6X
+Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 2056 2064 11 5.1 196.1 1.0X
+Native ORC Vectorized 415 421 7 25.3 39.6 5.0X
+Hive built-in ORC 2710 2722 17 3.9 258.4 0.8X
================================================================================================
String with Nulls Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 3324 / 3325 3.2 317.0 1.0X
-Native ORC Vectorized 1085 / 1106 9.7 103.4 3.1X
-Hive built-in ORC 5272 / 5299 2.0 502.8 0.6X
+String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 3655 3674 27 2.9 348.6 1.0X
+Native ORC Vectorized 1166 1167 1 9.0 111.2 3.1X
+Hive built-in ORC 5268 5305 52 2.0 502.4 0.7X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 3045 / 3046 3.4 290.4 1.0X
-Native ORC Vectorized 1248 / 1260 8.4 119.0 2.4X
-Hive built-in ORC 3989 / 3999 2.6 380.4 0.8X
+String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 3447 3467 27 3.0 328.8 1.0X
+Native ORC Vectorized 1222 1223 1 8.6 116.6 2.8X
+Hive built-in ORC 3947 3959 18 2.7 376.4 0.9X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 1692 / 1694 6.2 161.3 1.0X
-Native ORC Vectorized 471 / 493 22.3 44.9 3.6X
-Hive built-in ORC 2398 / 2411 4.4 228.7 0.7X
+String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 1912 1917 6 5.5 182.4 1.0X
+Native ORC Vectorized 477 484 5 22.0 45.5 4.0X
+Hive built-in ORC 2374 2386 17 4.4 226.4 0.8X
================================================================================================
Single Column Scan From Wide Columns
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 1371 / 1379 0.8 1307.5 1.0X
-Native ORC Vectorized 121 / 135 8.6 115.8 11.3X
-Hive built-in ORC 521 / 561 2.0 497.1 2.6X
+Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 290 350 102 3.6 276.1 1.0X
+Native ORC Vectorized 155 166 15 6.7 148.2 1.9X
+Hive built-in ORC 520 531 8 2.0 495.8 0.6X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 2711 / 2767 0.4 2585.5 1.0X
-Native ORC Vectorized 210 / 232 5.0 200.5 12.9X
-Hive built-in ORC 764 / 775 1.4 728.3 3.5X
+Single Column Scan from 200 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 365 406 73 2.9 347.9 1.0X
+Native ORC Vectorized 232 246 20 4.5 221.6 1.6X
+Hive built-in ORC 794 864 62 1.3 757.6 0.5X
-OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_222-b10 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------
-Native ORC MR 3979 / 3988 0.3 3794.4 1.0X
-Native ORC Vectorized 357 / 366 2.9 340.2 11.2X
-Hive built-in ORC 1091 / 1095 1.0 1040.5 3.6X
+Single Column Scan from 300 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Native ORC MR 501 544 40 2.1 477.6 1.0X
+Native ORC Vectorized 365 386 33 2.9 348.0 1.4X
+Hive built-in ORC 1153 1153 0 0.9 1099.8 0.4X
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index e7ff3a5f4be2b..7a9f5c67fc693 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -46,7 +46,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
}
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
TestHive.setCacheTables(true)
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
@@ -65,7 +65,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
RuleExecutor.resetMetrics()
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
TestHive.setCacheTables(false)
TimeZone.setDefault(originalTimeZone)
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index c7d953a731b9b..b0cf25c3a7813 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -37,7 +37,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte
private val originalLocale = Locale.getDefault
private val testTempDir = Utils.createTempDir()
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
TestHive.setCacheTables(true)
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
@@ -100,7 +100,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte
sql("set mapreduce.jobtracker.address=local")
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
TestHive.setCacheTables(false)
TimeZone.setDefault(originalTimeZone)
@@ -751,7 +751,7 @@ class HiveWindowFunctionQueryFileSuite
private val originalLocale = Locale.getDefault
private val testTempDir = Utils.createTempDir()
- override def beforeAll() {
+ override def beforeAll(): Unit = {
super.beforeAll()
TestHive.setCacheTables(true)
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
@@ -769,7 +769,7 @@ class HiveWindowFunctionQueryFileSuite
// sql("set mapreduce.jobtracker.address=local")
}
- override def afterAll() {
+ override def afterAll(): Unit = {
try {
TestHive.setCacheTables(false)
TimeZone.setDefault(originalTimeZone)
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index d37f0c8573659..f627227aa0380 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -103,14 +103,6 @@
       <groupId>${hive.group}</groupId>
       <artifactId>hive-metastore</artifactId>
     </dependency>
-    <dependency>
-      <groupId>${hive.group}</groupId>
-      <artifactId>hive-contrib</artifactId>
-    </dependency>
-    <dependency>
-      <groupId>${hive.group}.hcatalog</groupId>
-      <artifactId>hive-hcatalog-core</artifactId>
-    </dependency>