diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml index 4282504cc3984..fe01b92036377 100644 --- a/.github/workflows/master.yml +++ b/.github/workflows/master.yml @@ -9,148 +9,231 @@ on: - master jobs: + # TODO(SPARK-32248): Recover JDK 11 builds + # Build: build Spark and run the tests for specified modules. build: - + name: "Build modules: ${{ matrix.modules }} ${{ matrix.comment }} (JDK ${{ matrix.java }}, ${{ matrix.hadoop }}, ${{ matrix.hive }})" runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - java: [ '1.8', '11' ] - hadoop: [ 'hadoop-2.7', 'hadoop-3.2' ] - hive: [ 'hive-1.2', 'hive-2.3' ] - exclude: - - java: '11' - hive: 'hive-1.2' - - hadoop: 'hadoop-3.2' - hive: 'hive-1.2' - name: Build Spark - JDK${{ matrix.java }}/${{ matrix.hadoop }}/${{ matrix.hive }} - + java: + - 1.8 + hadoop: + - hadoop3.2 + hive: + - hive2.3 + # TODO(SPARK-32246): We don't test 'streaming-kinesis-asl' for now. + # Kinesis tests depends on external Amazon kinesis service. + # Note that the modules below are from sparktestsupport/modules.py. + modules: + - |- + core, unsafe, kvstore, avro, + network-common, network-shuffle, repl, launcher, + examples, sketch, graphx + - |- + catalyst, hive-thriftserver + - |- + streaming, sql-kafka-0-10, streaming-kafka-0-10, + mllib-local, mllib, + yarn, mesos, kubernetes, hadoop-cloud, spark-ganglia-lgpl + - |- + pyspark-sql, pyspark-mllib, pyspark-resource + - |- + pyspark-core, pyspark-streaming, pyspark-ml + - |- + sparkr + # Here, we split Hive and SQL tests into some of slow ones and the rest of them. + included-tags: [""] + excluded-tags: [""] + comment: [""] + include: + # Hive tests + - modules: hive + java: 1.8 + hadoop: hadoop3.2 + hive: hive2.3 + included-tags: org.apache.spark.tags.SlowHiveTest + comment: "- slow tests" + - modules: hive + java: 1.8 + hadoop: hadoop3.2 + hive: hive2.3 + excluded-tags: org.apache.spark.tags.SlowHiveTest + comment: "- other tests" + # SQL tests + - modules: sql + java: 1.8 + hadoop: hadoop3.2 + hive: hive2.3 + included-tags: org.apache.spark.tags.ExtendedSQLTest + comment: "- slow tests" + - modules: sql + java: 1.8 + hadoop: hadoop3.2 + hive: hive2.3 + excluded-tags: org.apache.spark.tags.ExtendedSQLTest + comment: "- other tests" + env: + MODULES_TO_TEST: ${{ matrix.modules }} + EXCLUDED_TAGS: ${{ matrix.excluded-tags }} + INCLUDED_TAGS: ${{ matrix.included-tags }} + HADOOP_PROFILE: ${{ matrix.hadoop }} + HIVE_PROFILE: ${{ matrix.hive }} + # GitHub Actions' default miniconda to use in pip packaging test. + CONDA_PREFIX: /usr/share/miniconda + GITHUB_PREV_SHA: ${{ github.event.before }} steps: - - uses: actions/checkout@master - # We split caches because GitHub Action Cache has a 400MB-size limit. - - uses: actions/cache@v1 + - name: Checkout Spark repository + uses: actions/checkout@v2 + # In order to fetch changed files + with: + fetch-depth: 0 + # Cache local repositories. Note that GitHub Actions cache has a 2G limit. 
+ - name: Cache Scala, SBT, Maven and Zinc + uses: actions/cache@v1 with: path: build key: build-${{ hashFiles('**/pom.xml') }} restore-keys: | build- - - uses: actions/cache@v1 + - name: Cache Maven local repository + uses: actions/cache@v2 with: - path: ~/.m2/repository/com - key: ${{ matrix.java }}-${{ matrix.hadoop }}-maven-com-${{ hashFiles('**/pom.xml') }} - restore-keys: | - ${{ matrix.java }}-${{ matrix.hadoop }}-maven-com- - - uses: actions/cache@v1 - with: - path: ~/.m2/repository/org - key: ${{ matrix.java }}-${{ matrix.hadoop }}-maven-org-${{ hashFiles('**/pom.xml') }} - restore-keys: | - ${{ matrix.java }}-${{ matrix.hadoop }}-maven-org- - - uses: actions/cache@v1 - with: - path: ~/.m2/repository/net - key: ${{ matrix.java }}-${{ matrix.hadoop }}-maven-net-${{ hashFiles('**/pom.xml') }} + path: ~/.m2/repository + key: ${{ matrix.java }}-${{ matrix.hadoop }}-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | - ${{ matrix.java }}-${{ matrix.hadoop }}-maven-net- - - uses: actions/cache@v1 + ${{ matrix.java }}-${{ matrix.hadoop }}-maven- + - name: Cache Ivy local repository + uses: actions/cache@v2 with: - path: ~/.m2/repository/io - key: ${{ matrix.java }}-${{ matrix.hadoop }}-maven-io-${{ hashFiles('**/pom.xml') }} + path: ~/.ivy2/cache + key: ${{ matrix.java }}-${{ matrix.hadoop }}-ivy-${{ hashFiles('**/pom.xml') }}-${{ hashFiles('**/plugins.sbt') }} restore-keys: | - ${{ matrix.java }}-${{ matrix.hadoop }}-maven-io- - - name: Set up JDK ${{ matrix.java }} + ${{ matrix.java }}-${{ matrix.hadoop }}-ivy- + - name: Install JDK ${{ matrix.java }} uses: actions/setup-java@v1 with: java-version: ${{ matrix.java }} - - name: Build with Maven - run: | - export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=1g -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN" - export MAVEN_CLI_OPTS="--no-transfer-progress" - mkdir -p ~/.m2 - ./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pmesos -Pkubernetes -Phive -P${{ matrix.hive }} -Phive-thriftserver -P${{ matrix.hadoop }} -Phadoop-cloud -Djava.version=${{ matrix.java }} install - rm -rf ~/.m2/repository/org/apache/spark - - - lint: - runs-on: ubuntu-latest - name: Linters (Java/Scala/Python), licenses, dependencies - steps: - - uses: actions/checkout@master - - uses: actions/setup-java@v1 + # PySpark + - name: Install PyPy3 + # Note that order of Python installations here matters because default python3 is + # overridden by pypy3. + uses: actions/setup-python@v2 + if: contains(matrix.modules, 'pyspark') with: - java-version: '11' - - uses: actions/setup-python@v1 + python-version: pypy3 + architecture: x64 + - name: Install Python 3.6 + uses: actions/setup-python@v2 + if: contains(matrix.modules, 'pyspark') with: - python-version: '3.x' - architecture: 'x64' - - name: Scala - run: ./dev/lint-scala - - name: Java - run: ./dev/lint-java - - name: Python - run: | - pip install flake8 sphinx numpy - ./dev/lint-python - - name: License - run: ./dev/check-license - - name: Dependencies - run: ./dev/test-dependencies.sh - - lintr: - runs-on: ubuntu-latest - name: Linter (R) - steps: - - uses: actions/checkout@master - - uses: actions/setup-java@v1 + python-version: 3.6 + architecture: x64 + - name: Install Python 3.8 + uses: actions/setup-python@v2 + # We should install one Python that is higher then 3+ for SQL and Yarn because: + # - SQL component also has Python related tests, for example, IntegratedUDFTestUtils. + # - Yarn has a Python specific test too, for example, YarnClusterSuite. 
+ if: contains(matrix.modules, 'yarn') || contains(matrix.modules, 'pyspark') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) with: - java-version: '11' - - uses: r-lib/actions/setup-r@v1 + python-version: 3.8 + architecture: x64 + - name: Install Python packages (Python 3.6 and PyPy3) + if: contains(matrix.modules, 'pyspark') + # PyArrow is not supported in PyPy yet, see ARROW-2651. + # TODO(SPARK-32247): scipy installation with PyPy fails for an unknown reason. + run: | + python3.6 -m pip install numpy pyarrow pandas scipy + python3.6 -m pip list + pypy3 -m pip install numpy pandas + pypy3 -m pip list + - name: Install Python packages (Python 3.8) + if: contains(matrix.modules, 'pyspark') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) + run: | + python3.8 -m pip install numpy pyarrow pandas scipy + python3.8 -m pip list + # SparkR + - name: Install R 3.6 + uses: r-lib/actions/setup-r@v1 + if: contains(matrix.modules, 'sparkr') with: - r-version: '3.6.2' - - name: Install lib + r-version: 3.6 + - name: Install R packages + if: contains(matrix.modules, 'sparkr') run: | sudo apt-get install -y libcurl4-openssl-dev - - name: install R packages + sudo Rscript -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'devtools', 'e1071', 'survival', 'arrow', 'roxygen2'), repos='https://cloud.r-project.org/')" + # Show installed packages in R. + sudo Rscript -e 'pkg_list <- as.data.frame(installed.packages()[, c(1,3:4)]); pkg_list[is.na(pkg_list$Priority), 1:2, drop = FALSE]' + # Run the tests. + - name: "Run tests: ${{ matrix.modules }}" run: | - sudo Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='https://cloud.r-project.org/')" - sudo Rscript -e "devtools::install_github('jimhester/lintr@v2.0.0')" - - name: package and install SparkR - run: ./R/install-dev.sh - - name: lint-r - run: ./dev/lint-r + # Hive tests become flaky when running in parallel as it's too intensive. 
+ if [[ "$MODULES_TO_TEST" == "hive" ]]; then export SERIAL_SBT_TESTS=1; fi + mkdir -p ~/.m2 + ./dev/run-tests --parallelism 2 --modules "$MODULES_TO_TEST" --included-tags "$INCLUDED_TAGS" --excluded-tags "$EXCLUDED_TAGS" + rm -rf ~/.m2/repository/org/apache/spark - docs: + # Static analysis, and documentation build + lint: + name: Linters, licenses, dependencies and documentation generation runs-on: ubuntu-latest - name: Generate documents steps: - - uses: actions/checkout@master - - uses: actions/cache@v1 + - name: Checkout Spark repository + uses: actions/checkout@v2 + - name: Cache Maven local repository + uses: actions/cache@v2 with: path: ~/.m2/repository key: docs-maven-repo-${{ hashFiles('**/pom.xml') }} restore-keys: | - docs-maven-repo- - - uses: actions/setup-java@v1 + docs-maven- + - name: Install JDK 1.8 + uses: actions/setup-java@v1 with: - java-version: '1.8' - - uses: actions/setup-python@v1 + java-version: 1.8 + - name: Install Python 3.6 + uses: actions/setup-python@v2 with: - python-version: '3.x' - architecture: 'x64' - - uses: actions/setup-ruby@v1 + python-version: 3.6 + architecture: x64 + - name: Install Python linter dependencies + run: | + pip3 install flake8 sphinx numpy + - name: Install R 3.6 + uses: r-lib/actions/setup-r@v1 with: - ruby-version: '2.7' - - uses: r-lib/actions/setup-r@v1 + r-version: 3.6 + - name: Install R linter dependencies and SparkR + run: | + sudo apt-get install -y libcurl4-openssl-dev + sudo Rscript -e "install.packages(c('devtools'), repos='https://cloud.r-project.org/')" + sudo Rscript -e "devtools::install_github('jimhester/lintr@v2.0.0')" + ./R/install-dev.sh + - name: Install Ruby 2.7 for documentation generation + uses: actions/setup-ruby@v1 with: - r-version: '3.6.2' - - name: Install lib and pandoc + ruby-version: 2.7 + - name: Install dependencies for documentation generation run: | sudo apt-get install -y libcurl4-openssl-dev pandoc - - name: Install packages - run: | pip install sphinx mkdocs numpy gem install jekyll jekyll-redirect-from rouge - sudo Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='https://cloud.r-project.org/')" - - name: Run jekyll build + sudo Rscript -e "install.packages(c('devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2'), repos='https://cloud.r-project.org/')" + - name: Scala linter + run: ./dev/lint-scala + - name: Java linter + run: ./dev/lint-java + - name: Python linter + run: ./dev/lint-python + - name: R linter + run: ./dev/lint-r + - name: License test + run: ./dev/check-license + - name: Dependencies test + run: ./dev/test-dependencies.sh + - name: Run documentation build run: | cd docs jekyll build diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index f86872d727a1d..1add5a9fdde44 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -139,7 +139,7 @@ test_that("utility function can be called", { expect_true(TRUE) }) -test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { +test_that("getClientModeSparkSubmitOpts() returns spark-submit args from allowList", { e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 611d9057c0f13..e008bc5bbd7d9 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ 
-3921,14 +3921,14 @@ test_that("No extra files are created in SPARK_HOME by starting session and maki # before creating a SparkSession with enableHiveSupport = T at the top of this test file # (filesBefore). The test here is to compare that (filesBefore) against the list of files before # any test is run in run-all.R (sparkRFilesBefore). - # sparkRWhitelistSQLDirs is also defined in run-all.R, and should contain only 2 whitelisted dirs, + # sparkRAllowedSQLDirs is also defined in run-all.R, and should contain only 2 allowed dirs, # here allow the first value, spark-warehouse, in the diff, everything else should be exactly the # same as before any test is run. - compare_list(sparkRFilesBefore, setdiff(filesBefore, sparkRWhitelistSQLDirs[[1]])) + compare_list(sparkRFilesBefore, setdiff(filesBefore, sparkRAllowedSQLDirs[[1]])) # third, ensure only spark-warehouse and metastore_db are created when enableHiveSupport = T # note: as the note above, after running all tests in this file while enableHiveSupport = T, we - # check the list of files again. This time we allow both whitelisted dirs to be in the diff. - compare_list(sparkRFilesBefore, setdiff(filesAfter, sparkRWhitelistSQLDirs)) + # check the list of files again. This time we allow both dirs to be in the diff. + compare_list(sparkRFilesBefore, setdiff(filesAfter, sparkRAllowedSQLDirs)) }) unlink(parquetPath) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index bf02ecdad66ff..a46924a5d20e3 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -35,8 +35,8 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { install.spark(overwrite = TRUE) sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") - sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") - invisible(lapply(sparkRWhitelistSQLDirs, + sparkRAllowedSQLDirs <- c("spark-warehouse", "metastore_db") + invisible(lapply(sparkRAllowedSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) diff --git a/appveyor.yml b/appveyor.yml index a4da5f9040ded..1fd91daae9015 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -42,8 +42,8 @@ install: # Install maven and dependencies - ps: .\dev\appveyor-install-dependencies.ps1 # Required package for R unit tests - - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival', 'arrow'), repos='https://cloud.r-project.org/')" - - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival'); packageVersion('arrow')" + - cmd: Rscript -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival', 'arrow'), repos='https://cloud.r-project.org/')" + - cmd: Rscript -e "pkg_list <- as.data.frame(installed.packages()[,c(1, 3:4)]); pkg_list[is.na(pkg_list$Priority), 1:2, drop = FALSE]" build_script: # '-Djna.nosys=true' is required to avoid kernel32.dll load failure. diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md index 14df703270498..7a9fa3a91d143 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md @@ -155,4 +155,4 @@ server will be able to understand. This will cause the server to close the conne attacker tries to send any command to the server. 
The attacker can just hold the channel open for some time, which will be closed when the server times out the channel. These issues could be separately mitigated by adding a shorter timeout for the first message after authentication, and -potentially by adding host blacklists if a possible attack is detected from a particular host. +potentially by adding host reject-lists if a possible attack is detected from a particular host. diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java index 6549cac011feb..e5e61aae92d2f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorDiskUtils.java @@ -18,25 +18,11 @@ package org.apache.spark.network.shuffle; import java.io.File; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import com.google.common.annotations.VisibleForTesting; - -import org.apache.commons.lang3.SystemUtils; import org.apache.spark.network.util.JavaUtils; public class ExecutorDiskUtils { - private static final Pattern MULTIPLE_SEPARATORS; - static { - if (SystemUtils.IS_OS_WINDOWS) { - MULTIPLE_SEPARATORS = Pattern.compile("[/\\\\]+"); - } else { - MULTIPLE_SEPARATORS = Pattern.compile("/{2,}"); - } - } - /** * Hashes a filename into the corresponding local directory, in a manner consistent with * Spark's DiskBlockManager.getFile(). @@ -45,34 +31,16 @@ public static File getFile(String[] localDirs, int subDirsPerLocalDir, String fi int hash = JavaUtils.nonNegativeHash(filename); String localDir = localDirs[hash % localDirs.length]; int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; - return new File(createNormalizedInternedPathname( - localDir, String.format("%02x", subDirId), filename)); - } - - /** - * This method is needed to avoid the situation when multiple File instances for the - * same pathname "foo/bar" are created, each with a separate copy of the "foo/bar" String. - * According to measurements, in some scenarios such duplicate strings may waste a lot - * of memory (~ 10% of the heap). To avoid that, we intern the pathname, and before that - * we make sure that it's in a normalized form (contains no "//", "///" etc.) Otherwise, - * the internal code in java.io.File would normalize it later, creating a new "foo/bar" - * String copy. Unfortunately, we cannot just reuse the normalization code that java.io.File - * uses, since it is in the package-private class java.io.FileSystem. - * - * On Windows, separator "\" is used instead of "/". - * - * "\\" is a legal character in path name on Unix-like OS, but illegal on Windows. 
- */ - @VisibleForTesting - static String createNormalizedInternedPathname(String dir1, String dir2, String fname) { - String pathname = dir1 + File.separator + dir2 + File.separator + fname; - Matcher m = MULTIPLE_SEPARATORS.matcher(pathname); - pathname = m.replaceAll(Matcher.quoteReplacement(File.separator)); - // A single trailing slash needs to be taken care of separately - if (pathname.length() > 1 && pathname.charAt(pathname.length() - 1) == File.separatorChar) { - pathname = pathname.substring(0, pathname.length() - 1); - } - return pathname.intern(); + final String notNormalizedPath = + localDir + File.separator + String.format("%02x", subDirId) + File.separator + filename; + // Interning the normalized path as according to measurements, in some scenarios such + // duplicate strings may waste a lot of memory (~ 10% of the heap). + // Unfortunately, we cannot just call the normalization code that java.io.File + // uses, since it is in the package-private class java.io.FileSystem. + // So we are creating a File just to get the normalized path back to intern it. + // Finally a new File is built and returned with this interned normalized path. + final String normalizedInternedPath = new File(notNormalizedPath).getPath().intern(); + return new File(normalizedInternedPath); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index ba1a17bf7e5ea..a6bcbb8850566 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -24,7 +24,6 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; -import java.util.regex.Pattern; import java.util.stream.Collectors; import org.apache.commons.lang3.builder.ToStringBuilder; @@ -71,8 +70,6 @@ public class ExternalShuffleBlockResolver { private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); - private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}"); - // Map containing all registered executors' metadata. 
@VisibleForTesting final ConcurrentMap executors; diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 6515b6ca035f7..88bcf43c2371f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.network.shuffle; -import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -25,7 +24,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.CharStreams; -import org.apache.commons.lang3.SystemUtils; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -145,29 +143,4 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); } - @Test - public void testNormalizeAndInternPathname() { - String sep = File.separator; - String expectedPathname = sep + "foo" + sep + "bar" + sep + "baz"; - assertPathsMatch("/foo", "bar", "baz", expectedPathname); - assertPathsMatch("//foo/", "bar/", "//baz", expectedPathname); - assertPathsMatch("/foo/", "/bar//", "/baz", expectedPathname); - assertPathsMatch("foo", "bar", "baz///", "foo" + sep + "bar" + sep + "baz"); - assertPathsMatch("/", "", "", sep); - assertPathsMatch("/", "/", "/", sep); - if (SystemUtils.IS_OS_WINDOWS) { - assertPathsMatch("/foo\\/", "bar", "baz", expectedPathname); - } else { - assertPathsMatch("/foo\\/", "bar", "baz", sep + "foo\\" + sep + "bar" + sep + "baz"); - } - } - - private void assertPathsMatch(String p1, String p2, String p3, String expectedPathname) { - String normPathname = - ExecutorDiskUtils.createNormalizedInternedPathname(p1, p2, p3); - assertEquals(expectedPathname, normPathname); - File file = new File(normPathname); - String returnedPath = file.getPath(); - assertEquals(normPathname, returnedPath); - } } diff --git a/common/tags/src/test/java/org/apache/spark/tags/SlowHiveTest.java b/common/tags/src/test/java/org/apache/spark/tags/SlowHiveTest.java new file mode 100644 index 0000000000000..a7e6f352667d7 --- /dev/null +++ b/common/tags/src/test/java/org/apache/spark/tags/SlowHiveTest.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.tags; + +import org.scalatest.TagAnnotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface SlowHiveTest { } diff --git a/conf/slaves.template b/conf/workers.template similarity index 100% rename from conf/slaves.template rename to conf/workers.template diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/core/src/main/java/org/apache/spark/status/api/v1/TaskStatus.java similarity index 74% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala rename to core/src/main/java/org/apache/spark/status/api/v1/TaskStatus.java index 47b1f78b24505..dec9c31321839 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskStatus.java @@ -15,13 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming.continuous.shuffle +package org.apache.spark.status.api.v1; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.EnumUtil; -/** - * Trait for writing to a continuous processing shuffle. - */ -trait ContinuousShuffleWriter { - def write(epoch: Iterator[UnsafeRow]): Unit +public enum TaskStatus { + RUNNING, + KILLED, + FAILED, + SUCCESS, + UNKNOWN; + + public static TaskStatus fromString(String str) { + return EnumUtil.parseIgnoreCase(TaskStatus.class, str); + } } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index e0ac2b3e0f4b8..620a6fe2f9d72 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -550,7 +550,7 @@ private[spark] class ExecutorAllocationManager( } else { // We don't want to change our target number of executors, because we already did that // when the task backlog decreased. 
- client.killExecutors(executorIdsToBeRemoved, adjustTargetNumExecutors = false, + client.killExecutors(executorIdsToBeRemoved.toSeq, adjustTargetNumExecutors = false, countFailures = false, force = false) } @@ -563,9 +563,9 @@ private[spark] class ExecutorAllocationManager( // reset the newExecutorTotal to the existing number of executors if (testing || executorsRemoved.nonEmpty) { - executorMonitor.executorsKilled(executorsRemoved) + executorMonitor.executorsKilled(executorsRemoved.toSeq) logInfo(s"Executors ${executorsRemoved.mkString(",")} removed due to idle timeout.") - executorsRemoved + executorsRemoved.toSeq } else { logWarning(s"Unable to reach the cluster manager to kill executor/s " + s"${executorIdsToBeRemoved.mkString(",")} or no executor eligible to kill!") diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 2ac72e66d6f32..c99698f99d904 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -80,7 +80,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) // executor ID -> timestamp of when the last heartbeat from this executor was received private val executorLastSeen = new HashMap[String, Long] - private val executorTimeoutMs = sc.conf.get(config.STORAGE_BLOCKMANAGER_SLAVE_TIMEOUT) + private val executorTimeoutMs = sc.conf.get(config.STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT) private val checkTimeoutIntervalMs = sc.conf.get(Network.NETWORK_TIMEOUT_INTERVAL) @@ -88,10 +88,10 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) require(checkTimeoutIntervalMs <= executorTimeoutMs, s"${Network.NETWORK_TIMEOUT_INTERVAL.key} should be less than or " + - s"equal to ${config.STORAGE_BLOCKMANAGER_SLAVE_TIMEOUT.key}.") + s"equal to ${config.STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT.key}.") require(executorHeartbeatIntervalMs <= executorTimeoutMs, s"${config.EXECUTOR_HEARTBEAT_INTERVAL.key} should be less than or " + - s"equal to ${config.STORAGE_BLOCKMANAGER_SLAVE_TIMEOUT.key}") + s"equal to ${config.STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT.key}") private var timeoutCheckingTask: ScheduledFuture[_] = null @@ -218,7 +218,8 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) sc.schedulerBackend match { case backend: CoarseGrainedSchedulerBackend => backend.driverEndpoint.send(RemoveExecutor(executorId, - SlaveLost(s"Executor heartbeat timed out after ${now - lastSeenMs} ms"))) + ExecutorProcessLost( + s"Executor heartbeat timed out after ${now - lastSeenMs} ms"))) // LocalSchedulerBackend is used locally and only has one single executor case _: LocalSchedulerBackend => diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 18cd5de4cfada..32251df6f4bbe 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -972,6 +972,6 @@ private[spark] object MapOutputTracker extends Logging { } } - splitsByAddress.iterator + splitsByAddress.mapValues(_.toSeq).iterator } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 40915e3904f7e..dbd89d646ae54 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -173,15 +173,6 @@ class SparkConf(loadDefaults: 
Boolean) extends Cloneable with Logging with Seria this } - /** - * Set multiple parameters together - */ - @deprecated("Use setAll(Iterable) instead", "3.0.0") - def setAll(settings: Traversable[(String, String)]): SparkConf = { - settings.foreach { case (k, v) => set(k, v) } - this - } - /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { if (settings.putIfAbsent(key, value) == null) { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 38d7319b1f0ef..06abc0541a9a9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -83,6 +83,9 @@ class SparkContext(config: SparkConf) extends Logging { // The call site where this SparkContext was constructed. private val creationSite: CallSite = Utils.getCallSite() + // In order to prevent SparkContext from being created in executors. + SparkContext.assertOnDriver() + // In order to prevent multiple SparkContexts from being active at the same time, mark this // context as having started construction. // NOTE: this must be placed at the beginning of the SparkContext constructor. @@ -1729,7 +1732,7 @@ class SparkContext(config: SparkConf) extends Logging { def version: String = SPARK_VERSION /** - * Return a map from the slave to the max memory available for caching and the remaining + * Return a map from the block manager to the max memory available for caching and the remaining * memory available for caching. */ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { @@ -2554,6 +2557,19 @@ object SparkContext extends Logging { } } + /** + * Called to ensure that SparkContext is created or accessed only on the Driver. + * + * Throws an exception if a SparkContext is about to be created in executors. + */ + private def assertOnDriver(): Unit = { + if (TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkContext should only be created and accessed on the driver.") + } + } + /** * This function may be used to get or instantiate a SparkContext and register it as a * singleton object. Because we can only have one active SparkContext per JVM, @@ -2814,14 +2830,14 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) - case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - checkResourcesPerTask(coresPerSlave.toInt) - // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. - val memoryPerSlaveInt = memoryPerSlave.toInt - if (sc.executorMemory > memoryPerSlaveInt) { + case LOCAL_CLUSTER_REGEX(numWorkers, coresPerWorker, memoryPerWorker) => + checkResourcesPerTask(coresPerWorker.toInt) + // Check to make sure memory requested <= memoryPerWorker. Otherwise Spark will just hang. 
+ val memoryPerWorkerInt = memoryPerWorker.toInt + if (sc.executorMemory > memoryPerWorkerInt) { throw new SparkException( "Asked to launch cluster with %d MiB RAM / worker but requested %d MiB/worker".format( - memoryPerSlaveInt, sc.executorMemory)) + memoryPerWorkerInt, sc.executorMemory)) } // For host local mode setting the default of SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED @@ -2834,7 +2850,7 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf) + numWorkers.toInt, coresPerWorker.toInt, memoryPerWorkerInt, sc.conf) val masterUrls = localCluster.start() val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 555c085d85a1e..37e673cd8c7e1 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -21,6 +21,7 @@ import java.util.Arrays import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1.StageStatus +import org.apache.spark.util.Utils /** * Low-level status reporting APIs for monitoring job and stage progress. @@ -103,10 +104,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore */ def getExecutorInfos: Array[SparkExecutorInfo] = { store.executorList(true).map { exec => - val (host, port) = exec.hostPort.split(":", 2) match { - case Array(h, p) => (h, p.toInt) - case Array(h) => (h, -1) - } + val (host, port) = Utils.parseHostPort(exec.hostPort) val cachedMem = exec.memoryMetrics.map { mem => mem.usedOnHeapStorageMemory + mem.usedOffHeapStorageMemory }.getOrElse(0L) diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 08a58a029528b..db4b74bb89f0c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -111,7 +111,7 @@ private[spark] class TaskContextImpl( if (failed) return failed = true failure = error - invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) { + invokeListeners(onFailureCallbacks.toSeq, "TaskFailureListener", Option(error)) { _.onTaskFailure(this, error) } } @@ -120,7 +120,7 @@ private[spark] class TaskContextImpl( private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { if (completed) return completed = true - invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { + invokeListeners(onCompleteCallbacks.toSeq, "TaskCompletionListener", error) { _.onTaskCompletion(this) } } @@ -142,7 +142,7 @@ private[spark] class TaskContextImpl( } } if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs, error) + throw new TaskCompletionListenerException(errorMsgs.toSeq, error) } } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index b13028f868072..6606d317e7b86 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -90,7 +90,8 @@ case class FetchFailed( extends TaskFailedReason { override def toErrorString: String = { val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString 
- s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapIndex=$mapIndex, " + + val mapIndexString = if (mapIndex == Int.MinValue) "Unknown" else mapIndex.toString + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapIndex=$mapIndexString, " + s"mapId=$mapId, reduceId=$reduceId, message=\n$message\n)" } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index e4140f659d979..15cb01a173287 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -256,7 +256,7 @@ object JavaRDD { } catch { case eof: EOFException => // No-op } - JavaRDD.fromRDD(sc.parallelize(objs, parallelism)) + JavaRDD.fromRDD(sc.parallelize(objs.toSeq, parallelism)) } finally { din.close() } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 1ca5262742665..89b33945dfb08 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -265,14 +265,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String]): JavaRDD[String] = { - rdd.pipe(command.asScala) + rdd.pipe(command.asScala.toSeq) } /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String], env: JMap[String, String]): JavaRDD[String] = { - rdd.pipe(command.asScala, env.asScala) + rdd.pipe(command.asScala.toSeq, env.asScala) } /** @@ -282,7 +282,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { env: JMap[String, String], separateWorkingDir: Boolean, bufferSize: Int): JavaRDD[String] = { - rdd.pipe(command.asScala, env.asScala, null, null, separateWorkingDir, bufferSize) + rdd.pipe(command.asScala.toSeq, env.asScala, null, null, separateWorkingDir, bufferSize) } /** @@ -293,7 +293,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { separateWorkingDir: Boolean, bufferSize: Int, encoding: String): JavaRDD[String] = { - rdd.pipe(command.asScala, env.asScala, null, null, separateWorkingDir, bufferSize, encoding) + rdd.pipe(command.asScala.toSeq, env.asScala, null, null, separateWorkingDir, bufferSize, + encoding) } /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 149def29b8fbd..39eb1ee731d50 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -74,7 +74,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable { /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). * @param appName A name for your application, to display on the cluster web UI - * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param sparkHome The SPARK_HOME directory on the worker nodes * @param jarFile JAR file to send to the cluster. This can be a path on the local file system * or an HDFS, HTTP, HTTPS, or FTP URL. */ @@ -84,7 +84,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable { /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). 
* @param appName A name for your application, to display on the cluster web UI - * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param sparkHome The SPARK_HOME directory on the worker nodes * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. */ @@ -94,7 +94,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable { /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). * @param appName A name for your application, to display on the cluster web UI - * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param sparkHome The SPARK_HOME directory on the worker nodes * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes @@ -133,7 +133,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable { /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag - sc.parallelize(list.asScala, numSlices) + sc.parallelize(list.asScala.toSeq, numSlices) } /** Get an RDD that has no partitions or elements. */ @@ -152,7 +152,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable { : JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[V] = fakeClassTag - JavaPairRDD.fromRDD(sc.parallelize(list.asScala, numSlices)) + JavaPairRDD.fromRDD(sc.parallelize(list.asScala.toSeq, numSlices)) } /** Distribute a local Scala collection to form an RDD. */ @@ -161,7 +161,7 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable { /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD = - JavaDoubleRDD.fromRDD(sc.parallelize(list.asScala.map(_.doubleValue()), numSlices)) + JavaDoubleRDD.fromRDD(sc.parallelize(list.asScala.map(_.doubleValue()).toSeq, numSlices)) /** Distribute a local Scala collection to form an RDD. 
*/ def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD = diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 726cff6703dcb..86a1ac31c0845 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -163,7 +163,7 @@ private[spark] object PythonRDD extends Logging { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = - sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala) + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala.toSeq) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) serveIterator(flattenedPartition.iterator, s"serve RDD ${rdd.id} with partitions ${partitions.asScala.mkString(",")}") diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 490b48719b6be..527d0d6d3a48d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -54,7 +54,7 @@ private[spark] object PythonUtils { * Convert list of T into seq of T (for calling API with varargs) */ def toSeq[T](vs: JList[T]): Seq[T] = { - vs.asScala + vs.asScala.toSeq } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 01e64b6972ae2..5a6fa507963f0 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -45,71 +45,6 @@ private[spark] object SerDeUtil extends Logging { } } } - // Unpickle array.array generated by Python 2.6 - class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor { - // /* Description of types */ - // static struct arraydescr descriptors[] = { - // {'c', sizeof(char), c_getitem, c_setitem}, - // {'b', sizeof(char), b_getitem, b_setitem}, - // {'B', sizeof(char), BB_getitem, BB_setitem}, - // #ifdef Py_USING_UNICODE - // {'u', sizeof(Py_UNICODE), u_getitem, u_setitem}, - // #endif - // {'h', sizeof(short), h_getitem, h_setitem}, - // {'H', sizeof(short), HH_getitem, HH_setitem}, - // {'i', sizeof(int), i_getitem, i_setitem}, - // {'I', sizeof(int), II_getitem, II_setitem}, - // {'l', sizeof(long), l_getitem, l_setitem}, - // {'L', sizeof(long), LL_getitem, LL_setitem}, - // {'f', sizeof(float), f_getitem, f_setitem}, - // {'d', sizeof(double), d_getitem, d_setitem}, - // {'\0', 0, 0, 0} /* Sentinel */ - // }; - val machineCodes: Map[Char, Int] = if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { - Map('B' -> 0, 'b' -> 1, 'H' -> 3, 'h' -> 5, 'I' -> 7, 'i' -> 9, - 'L' -> 11, 'l' -> 13, 'f' -> 15, 'd' -> 17, 'u' -> 21 - ) - } else { - Map('B' -> 0, 'b' -> 1, 'H' -> 2, 'h' -> 4, 'I' -> 6, 'i' -> 8, - 'L' -> 10, 'l' -> 12, 'f' -> 14, 'd' -> 16, 'u' -> 20 - ) - } - override def construct(args: Array[Object]): Object = { - if (args.length == 1) { - construct(args ++ Array("")) - } else if (args.length == 2 && args(1).isInstanceOf[String]) { - val typecode = args(0).asInstanceOf[String].charAt(0) - // This must be ISO 8859-1 / Latin 1, not UTF-8, to interoperate correctly - val data = args(1).asInstanceOf[String].getBytes(StandardCharsets.ISO_8859_1) - if (typecode == 'c') { - 
// It seems like the pickle of pypy uses the similar protocol to Python 2.6, which uses - // a string for array data instead of list as Python 2.7, and handles an array of - // typecode 'c' as 1-byte character. - val result = new Array[Char](data.length) - var i = 0 - while (i < data.length) { - result(i) = data(i).toChar - i += 1 - } - result - } else { - construct(typecode, machineCodes(typecode), data) - } - } else if (args.length == 2 && args(0) == "l") { - // On Python 2, an array of typecode 'l' should be handled as long rather than int. - val values = args(1).asInstanceOf[JArrayList[_]] - val result = new Array[Long](values.size) - var i = 0 - while (i < values.size) { - result(i) = values.get(i).asInstanceOf[Number].longValue() - i += 1 - } - result - } else { - super.construct(args) - } - } - } private var initialized = false // This should be called before trying to unpickle array.array from Python @@ -117,7 +52,6 @@ private[spark] object SerDeUtil extends Logging { def initialize(): Unit = { synchronized{ if (!initialized) { - Unpickler.registerConstructor("array", "array", new ArrayConstructor()) Unpickler.registerConstructor("__builtin__", "bytearray", new ByteArrayConstructor()) Unpickler.registerConstructor("builtins", "bytearray", new ByteArrayConstructor()) Unpickler.registerConstructor("__builtin__", "bytes", new ByteArrayConstructor()) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 6ff68b694f8f3..ab389f99b11a7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -205,7 +205,7 @@ private object FaultToleranceTest extends App with Logging { private def addWorkers(num: Int): Unit = { logInfo(s">>>>> ADD WORKERS $num <<<<<") - val masterUrls = getMasterUrls(masters) + val masterUrls = getMasterUrls(masters.toSeq) (1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) } } @@ -216,7 +216,7 @@ private object FaultToleranceTest extends App with Logging { // Counter-hack: Because of a hack in SparkEnv#create() that changes this // property, we need to reset it. 
System.setProperty(config.DRIVER_PORT.key, "0") - sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome) + sc = new SparkContext(getMasterUrls(masters.toSeq), "fault-tolerance", containerSparkHome) } private def getMasterUrls(masters: Seq[TestMasterInfo]): String = { @@ -279,7 +279,7 @@ private object FaultToleranceTest extends App with Logging { var liveWorkerIPs: Seq[String] = List() def stateValid(): Boolean = { - (workers.map(_.ip) -- liveWorkerIPs).isEmpty && + workers.map(_.ip).forall(liveWorkerIPs.contains) && numAlive == 1 && numStandby == masters.size - 1 && numLiveApps >= 1 } diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 6c3276c5c790a..17733d99cd5bc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -90,11 +90,12 @@ private[deploy] object JsonProtocol { * `name` the description of the application * `cores` total cores granted to the application * `user` name of the user who submitted the application - * `memoryperslave` minimal memory in MB required to each executor - * `resourcesperslave` minimal resources required to each executor + * `memoryperexecutor` minimal memory in MB required to each executor + * `resourcesperexecutor` minimal resources required to each executor * `submitdate` time in Date that the application is submitted * `state` state of the application, see [[ApplicationState]] * `duration` time in milliseconds that the application has been running + * For compatibility also returns the deprecated `memoryperslave` & `resourcesperslave` fields. */ def writeApplicationInfo(obj: ApplicationInfo): JObject = { ("id" -> obj.id) ~ @@ -102,7 +103,10 @@ private[deploy] object JsonProtocol { ("name" -> obj.desc.name) ~ ("cores" -> obj.coresGranted) ~ ("user" -> obj.desc.user) ~ + ("memoryperexecutor" -> obj.desc.memoryPerExecutorMB) ~ ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ + ("resourcesperexecutor" -> obj.desc.resourceReqsPerExecutor + .toList.map(writeResourceRequirement)) ~ ("resourcesperslave" -> obj.desc.resourceReqsPerExecutor .toList.map(writeResourceRequirement)) ~ ("submitdate" -> obj.submitDate.toString) ~ @@ -117,14 +121,17 @@ private[deploy] object JsonProtocol { * @return a Json object containing the following fields: * `name` the description of the application * `cores` max cores that can be allocated to the application, 0 means unlimited - * `memoryperslave` minimal memory in MB required to each executor - * `resourcesperslave` minimal resources required to each executor + * `memoryperexecutor` minimal memory in MB required to each executor + * `resourcesperexecutor` minimal resources required to each executor * `user` name of the user who submitted the application * `command` the command string used to submit the application + * For compatibility also returns the deprecated `memoryperslave` & `resourcesperslave` fields. 
*/ def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ ("cores" -> obj.maxCores.getOrElse(0)) ~ + ("memoryperexecutor" -> obj.memoryPerExecutorMB) ~ + ("resourcesperexecutor" -> obj.resourceReqsPerExecutor.toList.map(writeResourceRequirement)) ~ ("memoryperslave" -> obj.memoryPerExecutorMB) ~ ("resourcesperslave" -> obj.resourceReqsPerExecutor.toList.map(writeResourceRequirement)) ~ ("user" -> obj.user) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 574ce60b19b4e..7ad92da4e055a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -69,7 +69,7 @@ object PythonRunner { pathElements ++= formattedPyFiles pathElements += PythonUtils.sparkPythonPath pathElements += sys.env.getOrElse("PYTHONPATH", "") - val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*) + val pythonPath = PythonUtils.mergePythonPaths(pathElements.toSeq: _*) // Launch Python process val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1271a3dbfc3f6..6d38a1d281464 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -820,7 +820,7 @@ private[spark] class SparkSubmit extends Logging { } sparkConf.set(SUBMIT_PYTHON_FILES, formattedPyFiles.split(",").toSeq) - (childArgs, childClasspath, sparkConf, childMainClass) + (childArgs.toSeq, childClasspath.toSeq, sparkConf, childMainClass) } private def renameResourcesToLocalFS(resources: String, localResources: String): String = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 25ea75acc37d3..a73a5e9463204 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -128,6 +128,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val storePath = conf.get(LOCAL_STORE_DIR).map(new File(_)) private val fastInProgressParsing = conf.get(FAST_IN_PROGRESS_PARSING) + private val hybridStoreEnabled = conf.get(History.HYBRID_STORE_ENABLED) + // Visible for testing. 
private[history] val listing: KVStore = storePath.map { path => val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath()).toFile() @@ -158,6 +160,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) new HistoryServerDiskManager(conf, path, listing, clock) } + private var memoryManager: HistoryServerMemoryManager = null + if (hybridStoreEnabled) { + memoryManager = new HistoryServerMemoryManager(conf) + } + private val fileCompactor = new EventLogFileCompactor(conf, hadoopConf, fs, conf.get(EVENT_LOG_ROLLING_MAX_FILES_TO_RETAIN), conf.get(EVENT_LOG_COMPACTION_SCORE_THRESHOLD)) @@ -181,23 +188,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) processing.remove(path.getName) } - private val blacklist = new ConcurrentHashMap[String, Long] + private val inaccessibleList = new ConcurrentHashMap[String, Long] // Visible for testing - private[history] def isBlacklisted(path: Path): Boolean = { - blacklist.containsKey(path.getName) + private[history] def isAccessible(path: Path): Boolean = { + !inaccessibleList.containsKey(path.getName) } - private def blacklist(path: Path): Unit = { - blacklist.put(path.getName, clock.getTimeMillis()) + private def markInaccessible(path: Path): Unit = { + inaccessibleList.put(path.getName, clock.getTimeMillis()) } /** - * Removes expired entries in the blacklist, according to the provided `expireTimeInSeconds`. + * Removes expired entries in the inaccessibleList, according to the provided + * `expireTimeInSeconds`. */ - private def clearBlacklist(expireTimeInSeconds: Long): Unit = { + private def clearInaccessibleList(expireTimeInSeconds: Long): Unit = { val expiredThreshold = clock.getTimeMillis() - expireTimeInSeconds * 1000 - blacklist.asScala.retain((_, creationTime) => creationTime >= expiredThreshold) + inaccessibleList.asScala.retain((_, creationTime) => creationTime >= expiredThreshold) } private val activeUIs = new mutable.HashMap[(String, Option[String]), LoadedAppUI]() @@ -262,6 +270,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private def startPolling(): Unit = { diskManager.foreach(_.initialize()) + if (memoryManager != null) { + memoryManager.initialize() + } // Validate the log directory. val path = new Path(logDir) @@ -460,7 +471,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) - .filter { entry => !isBlacklisted(entry.getPath) } + .filter { entry => isAccessible(entry.getPath) } .filter { entry => !isProcessing(entry.getPath) } .flatMap { entry => EventLogFileReader(fs, entry) } .filter { reader => @@ -677,8 +688,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) case e: AccessControlException => // We don't have read permissions on the log file logWarning(s"Unable to read log $rootPath", e) - blacklist(rootPath) - // SPARK-28157 We should remove this blacklisted entry from the KVStore + markInaccessible(rootPath) + // SPARK-28157 We should remove this inaccessible entry from the KVStore // to handle permission-only changes with the same file sizes later. listing.delete(classOf[LogInfo], rootPath.toString) case e: Exception => @@ -946,8 +957,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - // Clean the blacklist from the expired entries. 
- clearBlacklist(CLEAN_INTERVAL_S) + // Clean the inaccessibleList from the expired entries. + clearInaccessibleList(CLEAN_INTERVAL_S) } private def deleteAttemptLogs( @@ -1167,6 +1178,95 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // At this point the disk data either does not exist or was deleted because it failed to // load, so the event log needs to be replayed. + // If the hybrid store is enabled, try it first and fail back to leveldb store. + if (hybridStoreEnabled) { + try { + return createHybridStore(dm, appId, attempt, metadata) + } catch { + case e: Exception => + logInfo(s"Failed to create HybridStore for $appId/${attempt.info.attemptId}." + + " Using LevelDB.", e) + } + } + + createLevelDBStore(dm, appId, attempt, metadata) + } + + private def createHybridStore( + dm: HistoryServerDiskManager, + appId: String, + attempt: AttemptInfoWrapper, + metadata: AppStatusStoreMetadata): KVStore = { + var retried = false + var hybridStore: HybridStore = null + val reader = EventLogFileReader(fs, new Path(logDir, attempt.logPath), + attempt.lastIndex) + + // Use InMemoryStore to rebuild app store + while (hybridStore == null) { + // A RuntimeException will be thrown if the heap memory is not sufficient + memoryManager.lease(appId, attempt.info.attemptId, reader.totalSize, + reader.compressionCodec) + var store: HybridStore = null + try { + store = new HybridStore() + rebuildAppStore(store, reader, attempt.info.lastUpdated.getTime()) + hybridStore = store + } catch { + case _: IOException if !retried => + // compaction may touch the file(s) which app rebuild wants to read + // compaction wouldn't run in short interval, so try again... + logWarning(s"Exception occurred while rebuilding log path ${attempt.logPath} - " + + "trying again...") + store.close() + memoryManager.release(appId, attempt.info.attemptId) + retried = true + case e: Exception => + store.close() + memoryManager.release(appId, attempt.info.attemptId) + throw e + } + } + + // Create a LevelDB and start a background thread to dump data to LevelDB + var lease: dm.Lease = null + try { + logInfo(s"Leasing disk manager space for app $appId / ${attempt.info.attemptId}...") + lease = dm.lease(reader.totalSize, reader.compressionCodec.isDefined) + val levelDB = KVUtils.open(lease.tmpPath, metadata) + hybridStore.setLevelDB(levelDB) + hybridStore.switchToLevelDB(new HybridStore.SwitchToLevelDBListener { + override def onSwitchToLevelDBSuccess: Unit = { + logInfo(s"Completely switched to LevelDB for app $appId / ${attempt.info.attemptId}.") + levelDB.close() + val newStorePath = lease.commit(appId, attempt.info.attemptId) + hybridStore.setLevelDB(KVUtils.open(newStorePath, metadata)) + memoryManager.release(appId, attempt.info.attemptId) + } + override def onSwitchToLevelDBFail(e: Exception): Unit = { + logWarning(s"Failed to switch to LevelDB for app $appId / ${attempt.info.attemptId}", e) + levelDB.close() + lease.rollback() + } + }, appId, attempt.info.attemptId) + } catch { + case e: Exception => + hybridStore.close() + memoryManager.release(appId, attempt.info.attemptId) + if (lease != null) { + lease.rollback() + } + throw e + } + + hybridStore + } + + private def createLevelDBStore( + dm: HistoryServerDiskManager, + appId: String, + attempt: AttemptInfoWrapper, + metadata: AppStatusStoreMetadata): KVStore = { var retried = false var newStorePath: File = null while (newStorePath == null) { @@ -1235,7 +1335,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private def 
deleteLog(fs: FileSystem, log: Path): Boolean = { var deleted = false - if (isBlacklisted(log)) { + if (!isAccessible(log)) { logDebug(s"Skipping deleting $log as we don't have permissions on it.") } else { try { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index aa9e9a6dd4887..ca21a8056d1b5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -76,9 +76,7 @@ class HistoryServer( // attempt ID (separated by a slash). val parts = Option(req.getPathInfo()).getOrElse("").split("/") if (parts.length < 2) { - res.sendError(HttpServletResponse.SC_BAD_REQUEST, - s"Unexpected path info in request (URI = ${req.getRequestURI()}") - return + res.sendRedirect("/") } val appId = parts(1) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala index b1adc3c112ed3..31f9d185174dc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala @@ -75,14 +75,29 @@ private class HistoryServerDiskManager( // Go through the recorded store directories and remove any that may have been removed by // external code. - val orphans = listing.view(classOf[ApplicationStoreInfo]).asScala.filter { info => - !new File(info.path).exists() - }.toSeq + val (existences, orphans) = listing + .view(classOf[ApplicationStoreInfo]) + .asScala + .toSeq + .partition { info => + new File(info.path).exists() + } orphans.foreach { info => listing.delete(info.getClass(), info.path) } + // Reading level db would trigger table file compaction, then it may cause size of level db + // directory changed. When service restarts, "currentUsage" is calculated from real directory + // size. Update "ApplicationStoreInfo.size" to ensure "currentUsage" equals + // sum of "ApplicationStoreInfo.size". + existences.foreach { info => + val fileSize = sizeOf(new File(info.path)) + if (fileSize != info.size) { + listing.write(info.copy(size = fileSize)) + } + } + logInfo("Initialized disk manager: " + s"current usage = ${Utils.bytesToString(currentUsage.get())}, " + s"max usage = ${Utils.bytesToString(maxUsage)}") @@ -235,7 +250,7 @@ private class HistoryServerDiskManager( } } - private def appStorePath(appId: String, attemptId: Option[String]): File = { + private[history] def appStorePath(appId: String, attemptId: Option[String]): File = { val fileName = appId + attemptId.map("_" + _).getOrElse("") + ".ldb" new File(appStoreDir, fileName) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerMemoryManager.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerMemoryManager.scala new file mode 100644 index 0000000000000..7fc0722233854 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerMemoryManager.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable.HashMap + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.History._ +import org.apache.spark.util.Utils + +/** + * A class used to keep track of in-memory store usage by the SHS. + */ +private class HistoryServerMemoryManager( + conf: SparkConf) extends Logging { + + private val maxUsage = conf.get(MAX_IN_MEMORY_STORE_USAGE) + private val currentUsage = new AtomicLong(0L) + private val active = new HashMap[(String, Option[String]), Long]() + + def initialize(): Unit = { + logInfo("Initialized memory manager: " + + s"current usage = ${Utils.bytesToString(currentUsage.get())}, " + + s"max usage = ${Utils.bytesToString(maxUsage)}") + } + + def lease( + appId: String, + attemptId: Option[String], + eventLogSize: Long, + codec: Option[String]): Unit = { + val memoryUsage = approximateMemoryUsage(eventLogSize, codec) + if (memoryUsage + currentUsage.get > maxUsage) { + throw new RuntimeException("Not enough memory to create hybrid store " + + s"for app $appId / $attemptId.") + } + active.synchronized { + active(appId -> attemptId) = memoryUsage + } + currentUsage.addAndGet(memoryUsage) + logInfo(s"Leasing ${Utils.bytesToString(memoryUsage)} memory usage for " + + s"app $appId / $attemptId") + } + + def release(appId: String, attemptId: Option[String]): Unit = { + val memoryUsage = active.synchronized { active.remove(appId -> attemptId) } + + memoryUsage match { + case Some(m) => + currentUsage.addAndGet(-m) + logInfo(s"Released ${Utils.bytesToString(m)} memory usage for " + + s"app $appId / $attemptId") + case None => + } + } + + private def approximateMemoryUsage(eventLogSize: Long, codec: Option[String]): Long = { + codec match { + case Some("zstd") => + eventLogSize * 10 + case Some(_) => + eventLogSize * 4 + case None => + eventLogSize / 2 + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HybridStore.scala b/core/src/main/scala/org/apache/spark/deploy/history/HybridStore.scala new file mode 100644 index 0000000000000..96db86f8e745a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/HybridStore.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
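The lease-sizing heuristic in HistoryServerMemoryManager above is easiest to see with concrete numbers. A rough, hypothetical worked example that just replays approximateMemoryUsage as shown (zstd logs are assumed to expand about 10x when parsed into memory, other codecs about 4x, uncompressed logs to shrink to roughly half):

    object LeaseSizeSketch {
      def approximateMemoryUsage(eventLogSize: Long, codec: Option[String]): Long = codec match {
        case Some("zstd") => eventLogSize * 10
        case Some(_)      => eventLogSize * 4
        case None         => eventLogSize / 2
      }

      def main(args: Array[String]): Unit = {
        val mib = 1024L * 1024
        println(approximateMemoryUsage(100 * mib, Some("zstd")) / mib) // 1000 (MiB)
        println(approximateMemoryUsage(100 * mib, None) / mib)         // 50 (MiB)
        // With the default spark.history.store.hybridStore.maxMemoryUsage of 2g, roughly two
        // zstd-compressed 100 MiB logs can hold in-memory leases at the same time.
      }
    }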
+ */ + +package org.apache.spark.deploy.history + +import java.io.IOException +import java.util.Collection +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.spark.util.kvstore._ + +/** + * An implementation of KVStore that accelerates event logs loading. + * + * When rebuilding the application state from event logs, HybridStore will + * write data to InMemoryStore at first and use a background thread to dump + * data to LevelDB once the app store is restored. We don't expect write + * operations (except the case for caching) after calling switch to level DB. + */ + +private[history] class HybridStore extends KVStore { + + private val inMemoryStore = new InMemoryStore() + + private var levelDB: LevelDB = null + + // Flag to indicate whether we should use inMemoryStore or levelDB + private val shouldUseInMemoryStore = new AtomicBoolean(true) + + // Flag to indicate whether this hybrid store is closed, use this flag + // to avoid starting background thread after the store is closed + private val closed = new AtomicBoolean(false) + + // A background thread that dumps data from inMemoryStore to levelDB + private var backgroundThread: Thread = null + + // A hash map that stores all classes that had been writen to inMemoryStore + private val klassMap = new ConcurrentHashMap[Class[_], Boolean] + + override def getMetadata[T](klass: Class[T]): T = { + getStore().getMetadata(klass) + } + + override def setMetadata(value: Object): Unit = { + getStore().setMetadata(value) + } + + override def read[T](klass: Class[T], naturalKey: Object): T = { + getStore().read(klass, naturalKey) + } + + override def write(value: Object): Unit = { + getStore().write(value) + + if (backgroundThread == null) { + // New classes won't be dumped once the background thread is started + klassMap.putIfAbsent(value.getClass(), true) + } + } + + override def delete(klass: Class[_], naturalKey: Object): Unit = { + if (backgroundThread != null) { + throw new IllegalStateException("delete() shouldn't be called after " + + "the hybrid store begins switching to levelDB") + } + + getStore().delete(klass, naturalKey) + } + + override def view[T](klass: Class[T]): KVStoreView[T] = { + getStore().view(klass) + } + + override def count(klass: Class[_]): Long = { + getStore().count(klass) + } + + override def count(klass: Class[_], index: String, indexedValue: Object): Long = { + getStore().count(klass, index, indexedValue) + } + + override def close(): Unit = { + try { + closed.set(true) + if (backgroundThread != null && backgroundThread.isAlive()) { + // The background thread is still running, wait for it to finish + backgroundThread.join() + } + } finally { + inMemoryStore.close() + if (levelDB != null) { + levelDB.close() + } + } + } + + override def removeAllByIndexValues[T]( + klass: Class[T], + index: String, + indexValues: Collection[_]): Boolean = { + if (backgroundThread != null) { + throw new IllegalStateException("removeAllByIndexValues() shouldn't be " + + "called after the hybrid store begins switching to levelDB") + } + + getStore().removeAllByIndexValues(klass, index, indexValues) + } + + def setLevelDB(levelDB: LevelDB): Unit = { + this.levelDB = levelDB + } + + /** + * This method is called when the writing is done for inMemoryStore. A + * background thread will be created and be started to dump data in inMemoryStore + * to levelDB. 
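The switching behaviour described in this scaladoc is easier to see in isolation. A much-simplified, hypothetical sketch (plain concurrent maps instead of KVStore/LevelDB) of the same idea: serve from the in-memory copy, dump it on a background daemon thread, then flip a flag so later reads hit the second store:

    import java.util.concurrent.ConcurrentHashMap
    import java.util.concurrent.atomic.AtomicBoolean
    import scala.collection.JavaConverters._

    object HybridDelegationSketch {
      private val inMemory = new ConcurrentHashMap[String, String]()
      private val onDisk = new ConcurrentHashMap[String, String]()   // stand-in for LevelDB
      private val useInMemory = new AtomicBoolean(true)

      private def backing = if (useInMemory.get) inMemory else onDisk

      def write(k: String, v: String): Unit = backing.put(k, v)
      def read(k: String): String = backing.get(k)

      // Copy everything to the slower store on a daemon thread, then flip the flag.
      def switchToDisk(): Thread = {
        val t = new Thread(() => {
          inMemory.asScala.foreach { case (k, v) => onDisk.put(k, v) }
          useInMemory.set(false)
        })
        t.setDaemon(true)
        t.start()
        t
      }

      def main(args: Array[String]): Unit = {
        write("app-1", "attempt-0")
        switchToDisk().join()
        println(read("app-1")) // now served from the second store
      }
    }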
Once the dumping is completed, the underlying kvstore will be + * switched to levelDB. + */ + def switchToLevelDB( + listener: HybridStore.SwitchToLevelDBListener, + appId: String, + attemptId: Option[String]): Unit = { + if (closed.get) { + return + } + + backgroundThread = new Thread(() => { + try { + for (klass <- klassMap.keys().asScala) { + val it = inMemoryStore.view(klass).closeableIterator() + while (it.hasNext()) { + levelDB.write(it.next()) + } + } + listener.onSwitchToLevelDBSuccess() + shouldUseInMemoryStore.set(false) + inMemoryStore.close() + } catch { + case e: Exception => + listener.onSwitchToLevelDBFail(e) + } + }) + backgroundThread.setDaemon(true) + backgroundThread.setName(s"hybridstore-$appId-$attemptId") + backgroundThread.start() + } + + /** + * This method return the store that we should use. + */ + private def getStore(): KVStore = { + if (shouldUseInMemoryStore.get) { + inMemoryStore + } else { + levelDB + } + } +} + +private[history] object HybridStore { + + trait SwitchToLevelDBListener { + + def onSwitchToLevelDBSuccess(): Unit + + def onSwitchToLevelDBFail(e: Exception): Unit + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 8eae445b439d9..ded816b992db8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -52,7 +52,7 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer override def read[T: ClassTag](prefix: String): Seq[T] = { zk.getChildren.forPath(workingDir).asScala - .filter(_.startsWith(prefix)).flatMap(deserializeFromFile[T]) + .filter(_.startsWith(prefix)).flatMap(deserializeFromFile[T]).toSeq } override def close(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 1648ba516d9b6..cc1d60a097b2e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -411,7 +411,7 @@ private[spark] object RestSubmissionClient { // SPARK_HOME and SPARK_CONF_DIR are filtered out because they are usually wrong // on the remote machine (SPARK-12345) (SPARK-25934) - private val BLACKLISTED_SPARK_ENV_VARS = Set("SPARK_ENV_LOADED", "SPARK_HOME", "SPARK_CONF_DIR") + private val EXCLUDED_SPARK_ENV_VARS = Set("SPARK_ENV_LOADED", "SPARK_HOME", "SPARK_CONF_DIR") private val REPORT_DRIVER_STATUS_INTERVAL = 1000 private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" @@ -421,8 +421,8 @@ private[spark] object RestSubmissionClient { */ private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { env.filterKeys { k => - (k.startsWith("SPARK_") && !BLACKLISTED_SPARK_ENV_VARS.contains(k)) || k.startsWith("MESOS_") - } + (k.startsWith("SPARK_") && !EXCLUDED_SPARK_ENV_VARS.contains(k)) || k.startsWith("MESOS_") + }.toMap } private[spark] def supportsRestClient(master: String): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index f7423f1fc3f1c..8240bd6d2f438 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -61,7 +61,7 @@ object CommandUtils extends Logging { // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows val cmd = new WorkerCommandBuilder(sparkHome, memory, command).buildCommand() - cmd.asScala ++ Seq(command.mainClass) ++ command.arguments + (cmd.asScala ++ Seq(command.mainClass) ++ command.arguments).toSeq } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 53ec7b3a88f35..4f9c497fc3d76 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -201,7 +201,7 @@ private[deploy] class DriverRunner( CommandUtils.redirectStream(process.getInputStream, stdout) val stderr = new File(baseDir, "stderr") - val redactedCommand = Utils.redactCommandLineArgs(conf, builder.command.asScala) + val redactedCommand = Utils.redactCommandLineArgs(conf, builder.command.asScala.toSeq) .mkString("\"", "\" \"", "\"") val header = "Launch Command: %s\n%s\n\n".format(redactedCommand, "=" * 40) Files.append(header, stderr, StandardCharsets.UTF_8) @@ -262,6 +262,6 @@ private[deploy] trait ProcessBuilderLike { private[deploy] object ProcessBuilderLike { def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { override def start(): Process = processBuilder.start() - override def command: Seq[String] = processBuilder.command().asScala + override def command: Seq[String] = processBuilder.command().asScala.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 2a5528bbe89cb..e4fcae13a2f89 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -158,7 +158,7 @@ private[deploy] class ExecutorRunner( val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() - val redactedCommand = Utils.redactCommandLineArgs(conf, command.asScala) + val redactedCommand = Utils.redactCommandLineArgs(conf, command.asScala.toSeq) .mkString("\"", "\" \"", "\"") logInfo(s"Launch command: $redactedCommand") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 6625457749f6a..e072d7919450e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -140,13 +140,13 @@ private[spark] class CoarseGrainedExecutorBackend( def extractLogUrls: Map[String, String] = { val prefix = "SPARK_LOG_URL_" sys.env.filterKeys(_.startsWith(prefix)) - .map(e => (e._1.substring(prefix.length).toLowerCase(Locale.ROOT), e._2)) + .map(e => (e._1.substring(prefix.length).toLowerCase(Locale.ROOT), e._2)).toMap } def extractAttributes: Map[String, String] = { val prefix = "SPARK_EXECUTOR_ATTRIBUTE_" sys.env.filterKeys(_.startsWith(prefix)) - .map(e => (e._1.substring(prefix.length).toUpperCase(Locale.ROOT), e._2)) + .map(e => (e._1.substring(prefix.length).toUpperCase(Locale.ROOT), e._2)).toMap } override 
def receive: PartialFunction[Any, Unit] = { @@ -304,8 +304,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val createFn: (RpcEnv, Arguments, SparkEnv, ResourceProfile) => CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env, resourceProfile) => new CoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId, - arguments.bindAddress, arguments.hostname, arguments.cores, arguments.userClassPath, env, - arguments.resourcesFileOpt, resourceProfile) + arguments.bindAddress, arguments.hostname, arguments.cores, arguments.userClassPath.toSeq, + env, arguments.resourcesFileOpt, resourceProfile) } run(parseArguments(args, this.getClass.getCanonicalName.stripSuffix("$")), createFn) System.exit(0) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c8b1afeebac0d..bc0f0c0a7b705 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -606,7 +606,8 @@ private[spark] class Executor( // Here and below, put task metric peaks in a WrappedArray to expose them as a Seq // without requiring a copy. val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId)) - val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums, metricPeaks)) + val serializedTK = ser.serialize( + TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq)) execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case _: InterruptedException | NonFatal(_) if @@ -616,7 +617,8 @@ private[spark] class Executor( val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs) val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId)) - val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums, metricPeaks)) + val serializedTK = ser.serialize( + TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq)) execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => @@ -661,13 +663,13 @@ private[spark] class Executor( val serializedTaskEndReason = { try { val ef = new ExceptionFailure(t, accUpdates).withAccums(accums) - .withMetricPeaks(metricPeaks) + .withMetricPeaks(metricPeaks.toSeq) ser.serialize(ef) } catch { case _: NotSerializableException => // t is not serializable so just send the stacktrace val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums) - .withMetricPeaks(metricPeaks) + .withMetricPeaks(metricPeaks.toSeq) ser.serialize(ef) } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 1470a23884bb0..43742a4d46cbb 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -123,7 +123,7 @@ class TaskMetrics private[spark] () extends Serializable { def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = { // This is called on driver. All accumulator updates have a fixed value. So it's safe to use // `asScala` which accesses the internal values using `java.util.Iterator`. 
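Similarly, the `.toMap` calls added to extractLogUrls and extractAttributes above are needed because filterKeys (and mapValues) return a lazy MapView on Scala 2.13 rather than a strict Map. A small sketch, using a made-up environment map rather than the real executor environment:

    object MapViewSketch {
      def main(args: Array[String]): Unit = {
        val env = Map("SPARK_LOG_URL_STDOUT" -> "http://host:8042/stdout", "PATH" -> "/usr/bin")
        val prefix = "SPARK_LOG_URL_"
        // On 2.13 the intermediate result is a view, so .toMap materializes it back into the
        // strict immutable Map that callers expect.
        val logUrls: Map[String, String] = env
          .filterKeys(_.startsWith(prefix))
          .map { case (k, v) => k.stripPrefix(prefix).toLowerCase -> v }
          .toMap
        println(logUrls) // Map(stdout -> http://host:8042/stdout)
      }
    }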
- _updatedBlockStatuses.value.asScala + _updatedBlockStatuses.value.asScala.toSeq } // Setters and increment-ers @@ -199,7 +199,7 @@ class TaskMetrics private[spark] () extends Serializable { */ private[spark] def mergeShuffleReadMetrics(): Unit = synchronized { if (tempShuffleReadMetrics.nonEmpty) { - shuffleReadMetrics.setMergeValues(tempShuffleReadMetrics) + shuffleReadMetrics.setMergeValues(tempShuffleReadMetrics.toSeq) } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/History.scala b/core/src/main/scala/org/apache/spark/internal/config/History.scala index 581777de366ef..a6d1c044130f5 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/History.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/History.scala @@ -195,4 +195,20 @@ private[spark] object History { .version("3.0.0") .booleanConf .createWithDefault(true) + + val HYBRID_STORE_ENABLED = ConfigBuilder("spark.history.store.hybridStore.enabled") + .doc("Whether to use HybridStore as the store when parsing event logs. " + + "HybridStore will first write data to an in-memory store and having a background thread " + + "that dumps data to a disk store after the writing to in-memory store is completed.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val MAX_IN_MEMORY_STORE_USAGE = ConfigBuilder("spark.history.store.hybridStore.maxMemoryUsage") + .doc("Maximum memory space that can be used to create HybridStore. The HybridStore co-uses " + + "the heap memory, so the heap memory should be increased through the memory option for SHS " + + "if the HybridStore is enabled.") + .version("3.1.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("2g") } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index ee437c696b47e..ca75a19af7bf6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -459,9 +459,10 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("60s") - private[spark] val STORAGE_BLOCKMANAGER_SLAVE_TIMEOUT = - ConfigBuilder("spark.storage.blockManagerSlaveTimeoutMs") + private[spark] val STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT = + ConfigBuilder("spark.storage.blockManagerHeartbeatTimeoutMs") .version("0.7.0") + .withAlternative("spark.storage.blockManagerSlaveTimeoutMs") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString(Network.NETWORK_TIMEOUT.defaultValueString) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 57dcbe501c6dd..48f816f649d36 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -156,7 +156,7 @@ private[spark] class MetricsSystem private ( } def getSourcesByName(sourceName: String): Seq[Source] = - sources.filter(_.sourceName == sourceName) + sources.filter(_.sourceName == sourceName).toSeq def registerSource(source: Source): Unit = { sources += source diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 9742d12cfe01e..d5f21112c0c9e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -81,7 +81,7 @@ private[spark] class HadoopPartition(rddId: Int, 
override val index: Int, s: Inp * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. - * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. + * Otherwise, a new JobConf will be created on each executor using the enclosed Configuration. * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD * creates. * @param inputFormatClass Storage format of the data to be read. @@ -140,7 +140,7 @@ class HadoopRDD[K, V]( private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) - // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. + // Returns a JobConf that will be used on executors to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value if (shouldCloneJobConf) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 1e39e10856877..f280c220a2c8d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -934,7 +934,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) for (pair <- it if pair._1 == key) { buf += pair._2 } - buf + buf.toSeq } : Seq[V] val res = self.context.runJob(self, process, Array(index)) res(0) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 9f8019b80a4dd..324cba5b4de42 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -133,12 +133,11 @@ private object ParallelCollectionRDD { // If the range is inclusive, use inclusive range for the last slice if (r.isInclusive && index == numSlices - 1) { new Range.Inclusive(r.start + start * r.step, r.end, r.step) - } - else { - new Range(r.start + start * r.step, r.start + end * r.step, r.step) + } else { + new Range.Inclusive(r.start + start * r.step, r.start + (end - 1) * r.step, r.step) } }.toSeq.asInstanceOf[Seq[Seq[T]]] - case nr: NumericRange[_] => + case nr: NumericRange[T] => // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) var r = nr @@ -147,7 +146,7 @@ private object ParallelCollectionRDD { slices += r.take(sliceSize).asInstanceOf[Seq[T]] r = r.drop(sliceSize) } - slices + slices.toSeq case _ => val array = seq.toArray // To prevent O(n^2) operations for List etc positions(array.length, numSlices).map { case (start, end) => diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 3b11e82dab196..5dd8cb8440be6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -238,7 +238,7 @@ private object PipedRDD { while(tok.hasMoreElements) { buf += tok.nextToken() } - buf + buf.toSeq } val STDIN_WRITER_THREAD_PREFIX = "stdin writer for" diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 63fa3c2487c33..0a93023443704 100644 --- 
a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -98,7 +98,7 @@ class UnionRDD[T: ClassTag]( deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length) pos += rdd.partitions.length } - deps + deps.toSeq } override def compute(s: Partition, context: TaskContext): Iterator[T] = { diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala index 1dbdc3d81e44d..f56ea69f6cec5 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala @@ -319,12 +319,13 @@ object ResourceProfile extends Logging { private[spark] def getCustomTaskResources( rp: ResourceProfile): Map[String, TaskResourceRequest] = { - rp.taskResources.filterKeys(k => !k.equals(ResourceProfile.CPUS)) + rp.taskResources.filterKeys(k => !k.equals(ResourceProfile.CPUS)).toMap } private[spark] def getCustomExecutorResources( rp: ResourceProfile): Map[String, ExecutorResourceRequest] = { - rp.executorResources.filterKeys(k => !ResourceProfile.allSupportedExecutorResources.contains(k)) + rp.executorResources. + filterKeys(k => !ResourceProfile.allSupportedExecutorResources.contains(k)).toMap } /* diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 37f9e0bb483c2..cb024d0852d06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1912,9 +1912,9 @@ private[spark] class DAGScheduler( * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. * * We will also assume that we've lost all shuffle blocks associated with the executor if the - * executor serves its own blocks (i.e., we're not using external shuffle), the entire slave - * is lost (likely including the shuffle service), or a FetchFailed occurred, in which case we - * presume all shuffle data related to this executor to be lost. + * executor serves its own blocks (i.e., we're not using external shuffle), the entire executor + * process is lost (likely including the shuffle service), or a FetchFailed occurred, in which + * case we presume all shuffle data related to this executor to be lost. * * Optionally the epoch during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. @@ -2273,7 +2273,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case ExecutorLost(execId, reason) => val workerLost = reason match { - case SlaveLost(_, true) => true + case ExecutorProcessLost(_, true) => true case _ => false } dagScheduler.handleExecutorLost(execId, workerLost) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index ee31093ec0652..4141ed799a4e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import org.apache.spark.executor.ExecutorExitCode /** - * Represents an explanation for an executor or whole slave failing or exiting. 
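Going back to the ParallelCollectionRDD change a little above: the new slice construction keeps the element set identical while building the range inclusively. A quick hypothetical check of that arithmetic for a non-final slice:

    object RangeSliceCheck {
      def main(args: Array[String]): Unit = {
        val r = 0 until 10 by 2      // the Range being sliced
        val (start, end) = (1, 3)    // slice positions [start, end)
        val exclusiveStyle = Range(r.start + start * r.step, r.start + end * r.step, r.step)
        val inclusiveStyle =
          Range.inclusive(r.start + start * r.step, r.start + (end - 1) * r.step, r.step)
        println(exclusiveStyle.toList) // List(2, 4)
        println(inclusiveStyle.toList) // List(2, 4) -- same elements, built inclusively
      }
    }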
+ * Represents an explanation for an executor or whole process failing or exiting. */ private[spark] class ExecutorLossReason(val message: String) extends Serializable { @@ -56,7 +56,7 @@ private [spark] object LossReasonPending extends ExecutorLossReason("Pending los * @param workerLost whether the worker is confirmed lost too (i.e. including shuffle service) */ private[spark] -case class SlaveLost(_message: String = "Slave lost", workerLost: Boolean = false) +case class ExecutorProcessLost(_message: String = "Worker lost", workerLost: Boolean = false) extends ExecutorLossReason(_message) /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 95b0096cade38..f13f1eaeeaa43 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -232,7 +232,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) { // For testing only. private[spark] def findListenersByClass[T <: SparkListenerInterface : ClassTag](): Seq[T] = { - queues.asScala.flatMap { queue => queue.findListenersByClass[T]() } + queues.asScala.flatMap { queue => queue.findListenersByClass[T]() }.toSeq } // For testing only. diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index b382d623806e2..a5858ebf9cdcc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -151,7 +151,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " + s"partition: $partition, attempt: $attemptNumber") case _ => - // Mark the attempt as failed to blacklist from future commit protocol + // Mark the attempt as failed to exclude from future commit protocol val taskId = TaskIdentifier(stageAttempt, attemptNumber) stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId if (stageState.authorizedCommitters(partition) == taskId) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala index bc1431835e258..6112d8ef051e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala @@ -69,7 +69,7 @@ object SplitInfo { for (host <- mapredSplit.getLocations) { retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit) } - retval + retval.toSeq } def toSplitInfo(inputFormatClazz: Class[_], path: String, @@ -79,6 +79,6 @@ object SplitInfo { for (host <- mapreduceSplit.getLocations) { retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit) } - retval + retval.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala index ca48775e77f27..be881481bf4ff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -47,19 +47,19 @@ class StatsReportListener extends SparkListener with Logging { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { implicit val sc = 
stageCompleted this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}") - showMillisDistribution("task runtime:", (info, _) => info.duration, taskInfoMetrics) + showMillisDistribution("task runtime:", (info, _) => info.duration, taskInfoMetrics.toSeq) // Shuffle write showBytesDistribution("shuffle bytes written:", - (_, metric) => metric.shuffleWriteMetrics.bytesWritten, taskInfoMetrics) + (_, metric) => metric.shuffleWriteMetrics.bytesWritten, taskInfoMetrics.toSeq) // Fetch & I/O showMillisDistribution("fetch wait time:", - (_, metric) => metric.shuffleReadMetrics.fetchWaitTime, taskInfoMetrics) + (_, metric) => metric.shuffleReadMetrics.fetchWaitTime, taskInfoMetrics.toSeq) showBytesDistribution("remote bytes read:", - (_, metric) => metric.shuffleReadMetrics.remoteBytesRead, taskInfoMetrics) + (_, metric) => metric.shuffleReadMetrics.remoteBytesRead, taskInfoMetrics.toSeq) showBytesDistribution("task result size:", - (_, metric) => metric.resultSize, taskInfoMetrics) + (_, metric) => metric.resultSize, taskInfoMetrics.toSeq) // Runtime breakdown val runtimePcts = taskInfoMetrics.map { case (info, metrics) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index b6df216d537e4..11d969e1aba90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -71,7 +71,7 @@ private[spark] class DirectTaskResult[T]( for (i <- 0 until numUpdates) { _accumUpdates += in.readObject.asInstanceOf[AccumulatorV2[_, _]] } - accumUpdates = _accumUpdates + accumUpdates = _accumUpdates.toSeq } val numMetrics = in.readInt diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index e9e638a3645ac..08f9f3c256e69 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -45,7 +45,7 @@ private[spark] trait TaskScheduler { // Invoked after system has successfully initialized (typically in spark context). // Yarn uses this to bootstrap allocation of resources based on preferred locations, - // wait for slave registrations, etc. + // wait for executor registrations, etc. def postStartHook(): Unit = { } // Disconnect from the cluster. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 2c37fec271766..12bd93286d736 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -137,7 +137,7 @@ private[spark] class TaskSchedulerImpl( private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]] def runningTasksByExecutors: Map[String, Int] = synchronized { - executorIdToRunningTaskIds.toMap.mapValues(_.size) + executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap } // The set of executors we have on each host; this is used to compute hostsAlive, which @@ -526,14 +526,14 @@ private[spark] class TaskSchedulerImpl( } /** - * Called by cluster manager to offer resources on slaves. We respond by asking our active task + * Called by cluster manager to offer resources on workers. We respond by asking our active task * sets for tasks in order of priority. 
We fill each node with tasks in a round-robin manner so * that tasks are balanced across the cluster. */ def resourceOffers( offers: IndexedSeq[WorkerOffer], isAllFreeResources: Boolean = true): Seq[Seq[TaskDescription]] = synchronized { - // Mark each slave as alive and remember its hostname + // Mark each worker as alive and remember its hostname // Also track if new executor is added var newExecAvail = false for (o <- offers) { @@ -719,7 +719,7 @@ private[spark] class TaskSchedulerImpl( if (tasks.nonEmpty) { hasLaunchedTask = true } - return tasks + return tasks.map(_.toSeq) } private def createUnschedulableTaskSetAbortTimer( @@ -765,7 +765,8 @@ private[spark] class TaskSchedulerImpl( }) if (executorIdToRunningTaskIds.contains(execId)) { reason = Some( - SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) + ExecutorProcessLost( + s"Task $tid was lost, so marking the executor as lost as well.")) removeExecutor(execId, reason.get) failedExecutor = Some(execId) } @@ -936,7 +937,7 @@ private[spark] class TaskSchedulerImpl( case None => // We may get multiple executorLost() calls with different loss reasons. For example, - // one may be triggered by a dropped connection from the slave while another may be a + // one may be triggered by a dropped connection from the worker while another may be a // report of executor termination from Mesos. We produce log messages for both so we // eventually report the termination reason. logError(s"Lost an executor $executorId (already removed): $reason") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 465c0d20de481..bb929c27b6a65 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -132,4 +132,6 @@ private[spark] object CoarseGrainedClusterMessages { // Used internally by executors to shut themselves down. case object Shutdown extends CoarseGrainedClusterMessage + // The message to check if `CoarseGrainedSchedulerBackend` thinks the executor is alive or not. + case class IsExecutorAlive(executorId: String) extends CoarseGrainedClusterMessage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 67638a5f9593c..6b9b4d6fe57e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -285,6 +285,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp Option(delegationTokens.get()), rp) context.reply(reply) + + case IsExecutorAlive(executorId) => context.reply(isExecutorActive(executorId)) + case e => logError(s"Received unexpected ask ${e}") } @@ -313,9 +316,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def onDisconnected(remoteAddress: RpcAddress): Unit = { addressToExecutorId .get(remoteAddress) - .foreach(removeExecutor(_, SlaveLost("Remote RPC client disassociated. Likely due to " + - "containers exceeding thresholds, or network issues. Check driver logs for WARN " + - "messages."))) + .foreach(removeExecutor(_, + ExecutorProcessLost("Remote RPC client disassociated. 
Likely due to " + + "containers exceeding thresholds, or network issues. Check driver logs for WARN " + + "messages."))) } // Make fake resource offers on just one executor @@ -379,7 +383,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } - // Remove a disconnected slave from the cluster + // Remove a disconnected executor from the cluster private def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { logDebug(s"Asked to remove executor $executorId with reason $reason") executorDataMap.get(executorId) match { @@ -553,7 +557,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Remove all the lingering executors that should be removed but not yet. The reason might be // because (1) disconnected event is not yet received; (2) executors die silently. executors.foreach { eid => - removeExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered.")) + removeExecutor(eid, + ExecutorProcessLost("Stale executor after cluster manager re-registered.")) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 42c46464d79e1..ec1299a924b5c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -168,7 +168,7 @@ private[spark] class StandaloneSchedulerBackend( fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit = { val reason: ExecutorLossReason = exitStatus match { case Some(code) => ExecutorExited(code, exitCausedByApp = true, message) - case None => SlaveLost(message, workerLost = workerLost) + case None => ExecutorProcessLost(message, workerLost = workerLost) } logInfo("Executor %s removed: %s".format(fullId, message)) removeExecutor(fullId.split("/")(1), reason) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 106d272948b9f..0a8d188dc1553 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -386,7 +386,8 @@ private[spark] class AppStatusStore( stageAttemptId: Int, offset: Int, length: Int, - sortBy: v1.TaskSorting): Seq[v1.TaskData] = { + sortBy: v1.TaskSorting, + statuses: JList[v1.TaskStatus]): Seq[v1.TaskData] = { val (indexName, ascending) = sortBy match { case v1.TaskSorting.ID => (None, true) @@ -395,7 +396,7 @@ private[spark] class AppStatusStore( case v1.TaskSorting.DECREASING_RUNTIME => (Some(TaskIndexNames.EXEC_RUN_TIME), false) } - taskList(stageId, stageAttemptId, offset, length, indexName, ascending) + taskList(stageId, stageAttemptId, offset, length, indexName, ascending, statuses) } def taskList( @@ -404,7 +405,8 @@ private[spark] class AppStatusStore( offset: Int, length: Int, sortBy: Option[String], - ascending: Boolean): Seq[v1.TaskData] = { + ascending: Boolean, + statuses: JList[v1.TaskStatus] = List().asJava): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) val base = store.view(classOf[TaskDataWrapper]) val indexed = sortBy match { @@ -417,7 +419,13 @@ private[spark] class AppStatusStore( } val ordered = if (ascending) indexed else indexed.reverse() - val taskDataWrapperIter = ordered.skip(offset).max(length).asScala + val 
taskDataWrapperIter = if (statuses != null && !statuses.isEmpty) { + val statusesStr = statuses.asScala.map(_.toString).toSet + ordered.asScala.filter(s => statusesStr.contains(s.status)).slice(offset, offset + length) + } else { + ordered.skip(offset).max(length).asScala + } + constructTaskDataList(taskDataWrapperIter) } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 86cb4fe138773..81478214994b0 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -33,7 +33,7 @@ import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} import org.apache.spark.status.api.v1 import org.apache.spark.storage.{RDDInfo, StorageLevel} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.{AccumulatorContext, Utils} import org.apache.spark.util.collection.OpenHashSet /** @@ -307,7 +307,7 @@ private[spark] class LiveExecutor(val executorId: String, _addTime: Long) extend // peak values for executor level metrics val peakExecutorMetrics = new ExecutorMetrics() - def hostname: String = if (host != null) host else hostPort.split(":")(0) + def hostname: String = if (host != null) host else Utils.parseHostPort(hostPort)._1 override protected def doUpdate(): Any = { val memoryMetrics = if (totalOnHeap >= 0) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 44ee322a22a10..05a7e96882d77 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -96,8 +96,9 @@ private[v1] class StagesResource extends BaseAppResource { @PathParam("stageAttemptId") stageAttemptId: Int, @DefaultValue("0") @QueryParam("offset") offset: Int, @DefaultValue("20") @QueryParam("length") length: Int, - @DefaultValue("ID") @QueryParam("sortBy") sortBy: TaskSorting): Seq[TaskData] = { - withUI(_.store.taskList(stageId, stageAttemptId, offset, length, sortBy)) + @DefaultValue("ID") @QueryParam("sortBy") sortBy: TaskSorting, + @QueryParam("status") statuses: JList[TaskStatus]): Seq[TaskData] = { + withUI(_.store.taskList(stageId, stageAttemptId, offset, length, sortBy, statuses)) } // This api needs to stay formatted exactly as it is below, since, it is being used by the diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index b40f7304b7ce2..5a164823297f9 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -26,6 +26,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ import org.apache.spark.ui.scope._ +import org.apache.spark.util.Utils import org.apache.spark.util.kvstore.KVIndex private[spark] case class AppStatusStoreMetadata(version: Long) @@ -57,7 +58,7 @@ private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { private def active: Boolean = info.isActive @JsonIgnore @KVIndex("host") - val host: String = info.hostPort.split(":")(0) + val host: String = Utils.parseHostPort(info.hostPort)._1 } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 219a0e799cc73..95d901f292971 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -367,7 +367,7 @@ private[storage] class BlockInfoManager extends Logging { notifyAll() - blocksWithReleasedLocks + blocksWithReleasedLocks.toSeq } /** Returns the number of locks held by the given task. Used only for testing. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0478ad09601d..6eec288015380 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -226,9 +226,9 @@ private[spark] class BlockManager( private val maxFailuresBeforeLocationRefresh = conf.get(config.BLOCK_FAILURES_BEFORE_LOCATION_REFRESH) - private val slaveEndpoint = rpcEnv.setupEndpoint( + private val storageEndpoint = rpcEnv.setupEndpoint( "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, - new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker)) + new BlockManagerStorageEndpoint(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. // Accesses should synchronize on asyncReregisterLock. @@ -465,7 +465,7 @@ private[spark] class BlockManager( diskBlockManager.localDirsString, maxOnHeapMemory, maxOffHeapMemory, - slaveEndpoint) + storageEndpoint) blockManagerId = if (idFromMaster != null) idFromMaster else id @@ -543,8 +543,8 @@ private[spark] class BlockManager( * an executor crash. * * This function deliberately fails silently if the master returns false (indicating that - * the slave needs to re-register). The error condition will be detected again by the next - * heart beat attempt or new block registration and another try to re-register all blocks + * the storage endpoint needs to re-register). The error condition will be detected again by the + * next heart beat attempt or new block registration and another try to re-register all blocks * will be made then. */ private def reportAllBlocks(): Unit = { @@ -568,7 +568,7 @@ private[spark] class BlockManager( // TODO: We might need to rate limit re-registering. logInfo(s"BlockManager $blockManagerId re-registering with master") master.registerBlockManager(blockManagerId, diskBlockManager.localDirsString, maxOnHeapMemory, - maxOffHeapMemory, slaveEndpoint) + maxOffHeapMemory, storageEndpoint) reportAllBlocks() } @@ -718,7 +718,7 @@ private[spark] class BlockManager( * * droppedMemorySize exists to account for when the block is dropped from memory to disk (so * it is still valid). This ensures that update in master will compensate for the increase in - * memory on slave. + * memory on the storage endpoint. */ private def reportBlockStatus( blockId: BlockId, @@ -736,7 +736,7 @@ private[spark] class BlockManager( /** * Actually send a UpdateBlockInfo message. Returns the master's response, * which will be true if the block was successfully recorded and false if - * the slave needs to re-register. + * the storage endpoint needs to re-register. */ private def tryToReportBlockStatus( blockId: BlockId, @@ -934,7 +934,7 @@ private[spark] class BlockManager( require(blockId != null, "BlockId is null") // Because all the remote blocks are registered in driver, it is not necessary to ask - // all the slave executors to get block status. 
+ // all the storage endpoints to get block status. val locationsAndStatusOption = master.getLocationsAndStatus(blockId, blockManagerId.host) if (locationsAndStatusOption.isEmpty) { logDebug(s"Block $blockId is unknown by block manager master") @@ -1960,7 +1960,7 @@ private[spark] class BlockManager( } remoteBlockTempFileManager.stop() diskBlockManager.stop() - rpcEnv.stop(slaveEndpoint) + rpcEnv.stop(storageEndpoint) blockInfoManager.clear() memoryStore.clear() futureExecutionContext.shutdownNow() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 3cfa5d2a25818..93492cc6d7db6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -71,10 +71,10 @@ class BlockManagerMaster( localDirs: Array[String], maxOnHeapMemSize: Long, maxOffHeapMemSize: Long, - slaveEndpoint: RpcEndpointRef): BlockManagerId = { + storageEndpoint: RpcEndpointRef): BlockManagerId = { logInfo(s"Registering BlockManager $id") val updatedId = driverEndpoint.askSync[BlockManagerId]( - RegisterBlockManager(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)) + RegisterBlockManager(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint)) logInfo(s"Registered BlockManager $updatedId") updatedId } @@ -128,7 +128,7 @@ class BlockManagerMaster( } /** - * Remove a block from the slaves that have it. This can only be used to remove + * Remove a block from the storage endpoints that have it. This can only be used to remove * blocks that the driver knows about. */ def removeBlock(blockId: BlockId): Unit = { @@ -142,7 +142,8 @@ class BlockManagerMaster( logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) )(ThreadUtils.sameThread) if (blocking) { - timeout.awaitResult(future) + // the underlying Futures will timeout anyway, so it's safe to use infinite timeout here + RpcUtils.INFINITE_TIMEOUT.awaitResult(future) } } @@ -153,7 +154,8 @@ class BlockManagerMaster( logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) )(ThreadUtils.sameThread) if (blocking) { - timeout.awaitResult(future) + // the underlying Futures will timeout anyway, so it's safe to use infinite timeout here + RpcUtils.INFINITE_TIMEOUT.awaitResult(future) } } @@ -166,7 +168,8 @@ class BlockManagerMaster( s" with removeFromMaster = $removeFromMaster - ${e.getMessage}", e) )(ThreadUtils.sameThread) if (blocking) { - timeout.awaitResult(future) + // the underlying Futures will timeout anyway, so it's safe to use infinite timeout here + RpcUtils.INFINITE_TIMEOUT.awaitResult(future) } } @@ -190,14 +193,14 @@ class BlockManagerMaster( * Return the block's status on all block managers, if any. NOTE: This is a * potentially expensive operation and should only be used for testing. * - * If askSlaves is true, this invokes the master to query each block manager for the most - * updated block statuses. This is useful when the master is not informed of the given block + * If askStorageEndpoints is true, this invokes the master to query each block manager for the + * most updated block statuses. This is useful when the master is not informed of the given block * by all block managers. 
*/ def getBlockStatus( blockId: BlockId, - askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { - val msg = GetBlockStatus(blockId, askSlaves) + askStorageEndpoints: Boolean = true): Map[BlockManagerId, BlockStatus] = { + val msg = GetBlockStatus(blockId, askStorageEndpoints) /* * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint * should not block on waiting for a block manager, which can in turn be waiting for the @@ -226,14 +229,14 @@ class BlockManagerMaster( * Return a list of ids of existing blocks such that the ids match the given filter. NOTE: This * is a potentially expensive operation and should only be used for testing. * - * If askSlaves is true, this invokes the master to query each block manager for the most - * updated block statuses. This is useful when the master is not informed of the given block + * If askStorageEndpoints is true, this invokes the master to query each block manager for the + * most updated block statuses. This is useful when the master is not informed of the given block * by all block managers. */ def getMatchingBlockIds( filter: BlockId => Boolean, - askSlaves: Boolean): Seq[BlockId] = { - val msg = GetMatchingBlockIds(filter, askSlaves) + askStorageEndpoints: Boolean): Seq[BlockId] = { + val msg = GetMatchingBlockIds(filter, askStorageEndpoints) val future = driverEndpoint.askSync[Future[Seq[BlockId]]](msg) timeout.awaitResult(future) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index d936420a99276..2a4817797a87c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -23,8 +23,9 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future, TimeoutException} import scala.util.Random +import scala.util.control.NonFatal import com.google.common.cache.CacheBuilder @@ -32,14 +33,15 @@ import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.shuffle.ExternalBlockStoreClient -import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} +import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointAddress, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * BlockManagerMasterEndpoint is an [[IsolatedRpcEndpoint]] on the master node to track statuses - * of all slaves' block managers. + * of all the storage endpoints' block managers. 
*/ private[spark] class BlockManagerMasterEndpoint( @@ -95,9 +97,12 @@ class BlockManagerMasterEndpoint( private val externalShuffleServiceRddFetchEnabled: Boolean = externalBlockStoreClient.isDefined private val externalShuffleServicePort: Int = StorageUtils.externalShuffleServicePort(conf) + private lazy val driverEndpoint = + RpcUtils.makeDriverRef(CoarseGrainedSchedulerBackend.ENDPOINT_NAME, conf, rpcEnv) + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterBlockManager(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) => - context.reply(register(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)) + case RegisterBlockManager(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint) => + context.reply(register(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint)) case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => @@ -130,14 +135,14 @@ class BlockManagerMasterEndpoint( case GetStorageStatus => context.reply(storageStatus) - case GetBlockStatus(blockId, askSlaves) => - context.reply(blockStatus(blockId, askSlaves)) + case GetBlockStatus(blockId, askStorageEndpoints) => + context.reply(blockStatus(blockId, askStorageEndpoints)) case IsExecutorAlive(executorId) => context.reply(blockManagerIdByExecutor.contains(executorId)) - case GetMatchingBlockIds(filter, askSlaves) => - context.reply(getMatchingBlockIds(filter, askSlaves)) + case GetMatchingBlockIds(filter, askStorageEndpoints) => + context.reply(getMatchingBlockIds(filter, askStorageEndpoints)) case RemoveRdd(rddId) => context.reply(removeRdd(rddId)) @@ -168,16 +173,60 @@ class BlockManagerMasterEndpoint( stop() } + /** + * A function that used to handle the failures when removing blocks. In general, the failure + * should be considered as non-fatal since it won't cause any correctness issue. Therefore, + * this function would prefer to log the exception and return the default value. We only throw + * the exception when there's a TimeoutException from an active executor, which implies the + * unhealthy status of the executor while the driver still not be aware of it. + * @param blockType should be one of "RDD", "shuffle", "broadcast", "block", used for log + * @param blockId the string value of a certain block id, used for log + * @param bmId the BlockManagerId of the BlockManager, where we're trying to remove the block + * @param defaultValue the return value of a failure removal. e.g., 0 means no blocks are removed + * @tparam T the generic type for defaultValue, Int or Boolean. + * @return the defaultValue or throw exception if the executor is active but reply late. + */ + private def handleBlockRemovalFailure[T]( + blockType: String, + blockId: String, + bmId: BlockManagerId, + defaultValue: T): PartialFunction[Throwable, T] = { + case e: IOException => + logWarning(s"Error trying to remove $blockType $blockId" + + s" from block manager $bmId", e) + defaultValue + + case t: TimeoutException => + val executorId = bmId.executorId + val isAlive = try { + driverEndpoint.askSync[Boolean](CoarseGrainedClusterMessages.IsExecutorAlive(executorId)) + } catch { + // ignore the non-fatal error from driverEndpoint since the caller doesn't really + // care about the return result of removing blocks. And so we could avoid breaking + // down the whole application. 
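handleBlockRemovalFailure above is a PartialFunction[Throwable, T] meant to be passed to Future.recover: IO errors and timeouts from lost executors are downgraded to a default value, while a timeout from an executor the driver still believes is alive is rethrown. Below is a minimal standalone sketch of that shape; isExecutorAlive is a hypothetical stub standing in for the IsExecutorAlive ask to the scheduler backend.

```scala
import java.io.IOException
import java.util.concurrent.TimeoutException
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object BlockRemovalRecoverySketch extends App {
  // Hypothetical liveness check standing in for the IsExecutorAlive RPC.
  def isExecutorAlive(executorId: String): Boolean = executorId == "exec-alive"

  // Downgrades removal failures to a default value; rethrows a timeout only when
  // the executor is still reported alive, i.e. it is likely unhealthy, not lost.
  def handleRemovalFailure[T](executorId: String, default: T): PartialFunction[Throwable, T] = {
    case e: IOException =>
      println(s"ignoring IO error from $executorId: ${e.getMessage}")
      default
    case t: TimeoutException =>
      if (isExecutorAlive(executorId)) {
        throw t
      } else {
        println(s"executor $executorId seems to be lost, using default")
        default
      }
  }

  val removal: Future[Int] =
    Future[Int] { throw new TimeoutException("no reply from endpoint") }
      .recover(handleRemovalFailure("exec-lost", default = 0))

  println(Await.result(removal, 5.seconds))   // 0, i.e. "no blocks removed"
}
```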
+ case NonFatal(e) => + logError(s"Fail to know the executor $executorId is alive or not.", e) + false + } + if (!isAlive) { + logWarning(s"Error trying to remove $blockType $blockId. " + + s"The executor $executorId may have been lost.", t) + defaultValue + } else { + throw t + } + } + private def removeRdd(rddId: Int): Future[Seq[Int]] = { // First remove the metadata for the given RDD, and then asynchronously remove the blocks - // from the slaves. + // from the storage endpoints. - // The message sent to the slaves to remove the RDD + // The message sent to the storage endpoints to remove the RDD val removeMsg = RemoveRdd(rddId) // Find all blocks for the given RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks and create the futures which asynchronously - // remove the blocks from slaves and gives back the number of removed blocks + // remove the blocks from storage endpoints and gives back the number of removed blocks val blocks = blockLocations.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) val blocksToDeleteByShuffleService = new mutable.HashMap[BlockManagerId, mutable.HashSet[RDDBlockId]] @@ -206,11 +255,9 @@ class BlockManagerMasterEndpoint( } } val removeRddFromExecutorsFutures = blockManagerInfo.values.map { bmInfo => - bmInfo.slaveEndpoint.ask[Int](removeMsg).recover { - case e: IOException => - logWarning(s"Error trying to remove RDD ${removeMsg.rddId} " + - s"from block manager ${bmInfo.blockManagerId}", e) - 0 // zero blocks were removed + bmInfo.storageEndpoint.ask[Int](removeMsg).recover { + // use 0 as default value means no blocks were removed + handleBlockRemovalFailure("RDD", rddId.toString, bmInfo.blockManagerId, 0) } }.toSeq @@ -229,13 +276,15 @@ class BlockManagerMasterEndpoint( Future.sequence(removeRddFromExecutorsFutures ++ removeRddBlockViaExtShuffleServiceFutures) } - private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { // Nothing to do in the BlockManagerMasterEndpoint data structures val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveEndpoint.ask[Boolean](removeMsg) + bm.storageEndpoint.ask[Boolean](removeMsg).recover { + // use false as default value means no shuffle data were removed + handleBlockRemovalFailure("shuffle", shuffleId.toString, bm.blockManagerId, false) + } }.toSeq ) } @@ -251,11 +300,9 @@ class BlockManagerMasterEndpoint( removeFromDriver || !info.blockManagerId.isDriver } val futures = requiredBlockManagers.map { bm => - bm.slaveEndpoint.ask[Int](removeMsg).recover { - case e: IOException => - logWarning(s"Error trying to remove broadcast $broadcastId from block manager " + - s"${bm.blockManagerId}", e) - 0 // zero blocks were removed + bm.storageEndpoint.ask[Int](removeMsg).recover { + // use 0 as default value means no blocks were removed + handleBlockRemovalFailure("broadcast", broadcastId.toString, bm.blockManagerId, 0) } }.toSeq @@ -295,7 +342,7 @@ class BlockManagerMasterEndpoint( blockManagerInfo.get(candidateBMId).foreach { bm => val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId) val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas) - bm.slaveEndpoint.ask[Boolean](replicateMsg) + bm.storageEndpoint.ask[Boolean](replicateMsg) } } } @@ -313,14 +360,14 @@ class BlockManagerMasterEndpoint( /** * Decommission the given Seq of blockmanagers * - Adds these block managers to decommissioningBlockManagerSet Set - * - Sends the DecommissionBlockManager message to each of 
the [[BlockManagerSlaveEndpoint]] + * - Sends the DecommissionBlockManager message to each of the [[BlockManagerReplicaEndpoint]] */ def decommissionBlockManagers(blockManagerIds: Seq[BlockManagerId]): Future[Seq[Unit]] = { val newBlockManagersToDecommission = blockManagerIds.toSet.diff(decommissioningBlockManagerSet) val futures = newBlockManagersToDecommission.map { blockManagerId => decommissioningBlockManagerSet.add(blockManagerId) val info = blockManagerInfo(blockManagerId) - info.slaveEndpoint.ask[Unit](DecommissionBlockManager) + info.storageEndpoint.ask[Unit](DecommissionBlockManager) } Future.sequence{ futures.toSeq } } @@ -343,18 +390,21 @@ class BlockManagerMasterEndpoint( }.toSeq } - // Remove a block from the slaves that have it. This can only be used to remove + // Remove a block from the workers that have it. This can only be used to remove // blocks that the master knows about. private def removeBlockFromWorkers(blockId: BlockId): Unit = { val locations = blockLocations.get(blockId) if (locations != null) { locations.foreach { blockManagerId: BlockManagerId => val blockManager = blockManagerInfo.get(blockManagerId) - if (blockManager.isDefined) { - // Remove the block from the slave's BlockManager. + blockManager.foreach { bm => + // Remove the block from the BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveEndpoint.ask[Boolean](RemoveBlock(blockId)) + bm.storageEndpoint.ask[Boolean](RemoveBlock(blockId)).recover { + // use false as default value means no blocks were removed + handleBlockRemovalFailure("block", blockId.toString, bm.blockManagerId, false) + } } } } @@ -378,13 +428,13 @@ class BlockManagerMasterEndpoint( * Return the block's status for all block managers, if any. NOTE: This is a * potentially expensive operation and should only be used for testing. * - * If askSlaves is true, the master queries each block manager for the most updated block - * statuses. This is useful when the master is not informed of the given block by all block + * If askStorageEndpoints is true, the master queries each block manager for the most updated + * block statuses. This is useful when the master is not informed of the given block by all block * managers. */ private def blockStatus( blockId: BlockId, - askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { + askStorageEndpoints: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { val getBlockStatus = GetBlockStatus(blockId) /* * Rather than blocking on the block status query, master endpoint should simply return @@ -393,8 +443,8 @@ class BlockManagerMasterEndpoint( */ blockManagerInfo.values.map { info => val blockStatusFuture = - if (askSlaves) { - info.slaveEndpoint.ask[Option[BlockStatus]](getBlockStatus) + if (askStorageEndpoints) { + info.storageEndpoint.ask[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -406,19 +456,19 @@ class BlockManagerMasterEndpoint( * Return the ids of blocks present in all the block managers that match the given filter. * NOTE: This is a potentially expensive operation and should only be used for testing. * - * If askSlaves is true, the master queries each block manager for the most updated block - * statuses. This is useful when the master is not informed of the given block by all block + * If askStorageEndpoints is true, the master queries each block manager for the most updated + * block statuses. 
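A small style point from removeBlockFromWorkers above: the isDefined/get pair on the Option is replaced with foreach, which runs the block only when a value is present and drops the unsafe .get. A tiny sketch of the idiom, with a hypothetical Endpoint class in place of the real RPC reference:

```scala
object OptionForeachSketch extends App {
  // Hypothetical endpoint type; in the diff above this is the BlockManagerInfo's RPC ref.
  final case class Endpoint(name: String) {
    def send(msg: String): Unit = println(s"$name <- $msg")
  }

  val maybeEndpoint: Option[Endpoint] = Some(Endpoint("bm-1"))

  // Before: explicit test plus .get, which becomes unsafe if the test is ever dropped.
  if (maybeEndpoint.isDefined) maybeEndpoint.get.send("RemoveBlock(rdd_0_0)")

  // After: foreach only runs when a value is present, with no .get at all.
  maybeEndpoint.foreach(_.send("RemoveBlock(rdd_0_0)"))
}
```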
This is useful when the master is not informed of the given block by all block * managers. */ private def getMatchingBlockIds( filter: BlockId => Boolean, - askSlaves: Boolean): Future[Seq[BlockId]] = { + askStorageEndpoints: Boolean): Future[Seq[BlockId]] = { val getMatchingBlockIds = GetMatchingBlockIds(filter) Future.sequence( blockManagerInfo.values.map { info => val future = - if (askSlaves) { - info.slaveEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) + if (askStorageEndpoints) { + info.storageEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.asScala.keys.filter(filter).toSeq } } @@ -441,7 +491,7 @@ class BlockManagerMasterEndpoint( localDirs: Array[String], maxOnHeapMemSize: Long, maxOffHeapMemSize: Long, - slaveEndpoint: RpcEndpointRef): BlockManagerId = { + storageEndpoint: RpcEndpointRef): BlockManagerId = { // the dummy id is not expected to contain the topology information. // we get that info here and respond back with a more fleshed out block manager id val id = BlockManagerId( @@ -476,7 +526,7 @@ class BlockManagerMasterEndpoint( } blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(), - maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint, externalShuffleServiceBlockStatus) + maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint, externalShuffleServiceBlockStatus) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) @@ -530,7 +580,7 @@ class BlockManagerMasterEndpoint( } } - // Remove the block from master tracking if it has been removed on all slaves. + // Remove the block from master tracking if it has been removed on all endpoints. if (locations.size == 0) { blockLocations.remove(blockId) } @@ -591,14 +641,14 @@ class BlockManagerMasterEndpoint( } /** - * Returns an [[RpcEndpointRef]] of the [[BlockManagerSlaveEndpoint]] for sending RPC messages. + * Returns an [[RpcEndpointRef]] of the [[BlockManagerReplicaEndpoint]] for sending RPC messages. */ private def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); info <- blockManagerInfo.get(blockManagerId) ) yield { - info.slaveEndpoint + info.storageEndpoint } } @@ -622,7 +672,7 @@ private[spark] class BlockManagerInfo( timeMs: Long, val maxOnHeapMem: Long, val maxOffHeapMem: Long, - val slaveEndpoint: RpcEndpointRef, + val storageEndpoint: RpcEndpointRef, val externalShuffleServiceBlockStatus: Option[JHashMap[BlockId, BlockStatus]]) extends Logging { @@ -656,7 +706,7 @@ private[spark] class BlockManagerInfo( var originalLevel: StorageLevel = StorageLevel.NONE if (blockExists) { - // The block exists on the slave already. + // The block exists on the storage endpoint already. val blockStatus: BlockStatus = _blocks.get(blockId) originalLevel = blockStatus.storageLevel originalMemSize = blockStatus.memSize diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 7d4f2fff5c34c..bbc076cea9ba8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -24,37 +24,37 @@ import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { ////////////////////////////////////////////////////////////////////////////////// - // Messages from the master to slaves. 
+ // Messages from the master to storage endpoints. ////////////////////////////////////////////////////////////////////////////////// - sealed trait ToBlockManagerSlave + sealed trait ToBlockManagerMasterStorageEndpoint - // Remove a block from the slaves that have it. This can only be used to remove + // Remove a block from the storage endpoints that have it. This can only be used to remove // blocks that the master knows about. - case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave + case class RemoveBlock(blockId: BlockId) extends ToBlockManagerMasterStorageEndpoint // Replicate blocks that were lost due to executor failure case class ReplicateBlock(blockId: BlockId, replicas: Seq[BlockManagerId], maxReplicas: Int) - extends ToBlockManagerSlave + extends ToBlockManagerMasterStorageEndpoint - case object DecommissionBlockManager extends ToBlockManagerSlave + case object DecommissionBlockManager extends ToBlockManagerMasterStorageEndpoint // Remove all blocks belonging to a specific RDD. - case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + case class RemoveRdd(rddId: Int) extends ToBlockManagerMasterStorageEndpoint // Remove all blocks belonging to a specific shuffle. - case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave + case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerMasterStorageEndpoint // Remove all blocks belonging to a specific broadcast. case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) - extends ToBlockManagerSlave + extends ToBlockManagerMasterStorageEndpoint /** * Driver to Executor message to trigger a thread dump. */ - case object TriggerThreadDump extends ToBlockManagerSlave + case object TriggerThreadDump extends ToBlockManagerMasterStorageEndpoint ////////////////////////////////////////////////////////////////////////////////// - // Messages from slaves to the master. + // Messages from storage endpoints to the master. ////////////////////////////////////////////////////////////////////////////////// sealed trait ToBlockManagerMaster @@ -132,10 +132,10 @@ private[spark] object BlockManagerMessages { case class GetReplicateInfoForRDDBlocks(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true) + case class GetBlockStatus(blockId: BlockId, askStorageEndpoints: Boolean = true) extends ToBlockManagerMaster - case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true) + case class GetMatchingBlockIds(filter: BlockId => Boolean, askStorageEndpoints: Boolean = true) extends ToBlockManagerMaster case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala similarity index 94% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala index a3a7149103491..a69bebc23c661 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala @@ -27,17 +27,17 @@ import org.apache.spark.util.{ThreadUtils, Utils} /** * An RpcEndpoint to take commands from the master to execute options. For example, - * this is used to remove blocks from the slave's BlockManager. 
+ * this is used to remove blocks from the storage endpoint's BlockManager. */ private[storage] -class BlockManagerSlaveEndpoint( +class BlockManagerStorageEndpoint( override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) extends IsolatedRpcEndpoint with Logging { private val asyncThreadPool = - ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool", 100) + ThreadUtils.newDaemonCachedThreadPool("block-manager-storage-async-thread-pool", 100) private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index bf76eef443e81..5db4965b67347 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -98,7 +98,7 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea } }.filter(_ != null).flatMap { dir => val files = dir.listFiles() - if (files != null) files else Seq.empty + if (files != null) files.toSeq else Seq.empty } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 5efbc0703f729..a2843da0561e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -368,25 +368,25 @@ final class ShuffleBlockFetcherIterator( collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { val iterator = blockInfos.iterator var curRequestSize = 0L - var curBlocks = new ArrayBuffer[FetchBlockInfo] + var curBlocks = Seq.empty[FetchBlockInfo] while (iterator.hasNext) { val (blockId, size, mapIndex) = iterator.next() assertPositiveBlockSize(blockId, size) - curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex)) curRequestSize += size // For batch fetch, the actual block in flight should count for merged block. 
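The loop being modified in ShuffleBlockFetcherIterator accumulates per-block fetch info and flushes it into a request once the accumulated size reaches the target request size or the block count reaches the per-address limit. The sketch below shows that batching rule in isolation, with hypothetical BlockInfo and FetchRequest case classes; it deliberately ignores batch-fetch merging and the leftover blocks that the real createFetchRequests can hand back.

```scala
object FetchBatchingSketch extends App {
  // Hypothetical, simplified block descriptor and request types.
  final case class BlockInfo(id: String, size: Long)
  final case class FetchRequest(blocks: Seq[BlockInfo]) {
    def totalSize: Long = blocks.map(_.size).sum
  }

  // Group blocks into requests, closing a request once the accumulated size reaches
  // targetRequestSize or the block count reaches maxBlocksPerRequest.
  def batch(
      blocks: Seq[BlockInfo],
      targetRequestSize: Long,
      maxBlocksPerRequest: Int): Seq[FetchRequest] = {
    val requests = Seq.newBuilder[FetchRequest]
    var current = Seq.empty[BlockInfo]
    var currentSize = 0L
    blocks.foreach { b =>
      current = current :+ b
      currentSize += b.size
      if (currentSize >= targetRequestSize || current.size >= maxBlocksPerRequest) {
        requests += FetchRequest(current)
        current = Seq.empty
        currentSize = 0L
      }
    }
    if (current.nonEmpty) requests += FetchRequest(current)   // final partial request
    requests.result()
  }

  val blocks = Seq(BlockInfo("b0", 40), BlockInfo("b1", 70), BlockInfo("b2", 10), BlockInfo("b3", 5))
  batch(blocks, targetRequestSize = 100L, maxBlocksPerRequest = 3)
    .foreach(r => println(s"${r.blocks.map(_.id)} -> ${r.totalSize} bytes"))
}
```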
val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { curBlocks = createFetchRequests(curBlocks, address, isLast = false, - collectedRemoteRequests).to[ArrayBuffer] + collectedRemoteRequests) curRequestSize = curBlocks.map(_.size).sum } } // Add in the final request if (curBlocks.nonEmpty) { curBlocks = createFetchRequests(curBlocks, address, isLast = true, - collectedRemoteRequests).to[ArrayBuffer] + collectedRemoteRequests) curRequestSize = curBlocks.map(_.size).sum } } @@ -928,7 +928,7 @@ object ShuffleBlockFetcherIterator { } else { blocks } - result + result.toSeq } /** diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 087a22d6c6140..a070cc9c7b39d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -443,7 +443,7 @@ private[spark] object UIUtils extends Logging { case None => {getHeaderContent(x._1)} } - } + }.toSeq } {headerRow} diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 9faa3dcf2cdf2..a4e87704927c6 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -58,11 +58,11 @@ private[spark] abstract class WebUI( private val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath - def getTabs: Seq[WebUITab] = tabs - def getHandlers: Seq[ServletContextHandler] = handlers + def getTabs: Seq[WebUITab] = tabs.toSeq + def getHandlers: Seq[ServletContextHandler] = handlers.toSeq def getDelegatingHandlers: Seq[DelegatingServletContextHandler] = { - handlers.map(new DelegatingServletContextHandler(_)) + handlers.map(new DelegatingServletContextHandler(_)).toSeq } /** Attaches a tab to this UI, along with all of its attached pages. 
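The recurring .toSeq additions in this file and the surrounding ones are driven by Scala 2.13, where the default Seq alias points at scala.collection.immutable.Seq, so a mutable buffer no longer satisfies a Seq return type on its own. A minimal illustration; the handlers name here is just an example, not the WebUI field itself:

```scala
import scala.collection.mutable.ArrayBuffer

object ToSeqSketch extends App {
  // On Scala 2.13, returning the ArrayBuffer directly would not compile here,
  // because the expected Seq is scala.collection.immutable.Seq.
  def handlers: Seq[String] = {
    val buf = ArrayBuffer("static", "api", "metrics")
    buf.toSeq   // copies the buffer into an immutable Seq
  }

  println(handlers.mkString(", "))   // static, api, metrics
}
```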
*/ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 066512d159d00..4e76ea289ede6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -259,11 +259,11 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We } val activeJobsTable = - jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) + jobsTable(request, "active", "activeJob", activeJobs.toSeq, killEnabled = parent.killEnabled) val completedJobsTable = - jobsTable(request, "completed", "completedJob", completedJobs, killEnabled = false) + jobsTable(request, "completed", "completedJob", completedJobs.toSeq, killEnabled = false) val failedJobsTable = - jobsTable(request, "failed", "failedJob", failedJobs, killEnabled = false) + jobsTable(request, "failed", "failedJob", failedJobs.toSeq, killEnabled = false) val shouldShowActiveJobs = activeJobs.nonEmpty val shouldShowCompletedJobs = completedJobs.nonEmpty @@ -330,7 +330,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We var content = summary - content ++= makeTimeline(activeJobs ++ completedJobs ++ failedJobs, + content ++= makeTimeline((activeJobs ++ completedJobs ++ failedJobs).toSeq, store.executorList(false), startTime) if (shouldShowActiveJobs) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 542dc39eee4f0..bba5e3dda6c47 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -288,20 +288,20 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP } val activeStagesTable = - new StageTableBase(store, request, activeStages, "active", "activeStage", parent.basePath, - basePath, parent.isFairScheduler, + new StageTableBase(store, request, activeStages.toSeq, "active", "activeStage", + parent.basePath, basePath, parent.isFairScheduler, killEnabled = parent.killEnabled, isFailedStage = false) val pendingOrSkippedStagesTable = - new StageTableBase(store, request, pendingOrSkippedStages, pendingOrSkippedTableId, + new StageTableBase(store, request, pendingOrSkippedStages.toSeq, pendingOrSkippedTableId, "pendingStage", parent.basePath, basePath, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val completedStagesTable = - new StageTableBase(store, request, completedStages, "completed", "completedStage", + new StageTableBase(store, request, completedStages.toSeq, "completed", "completedStage", parent.basePath, basePath, parent.isFairScheduler, killEnabled = false, isFailedStage = false) val failedStagesTable = - new StageTableBase(store, request, failedStages, "failed", "failedStage", parent.basePath, - basePath, parent.isFairScheduler, + new StageTableBase(store, request, failedStages.toSeq, "failed", "failedStage", + parent.basePath, basePath, parent.isFairScheduler, killEnabled = false, isFailedStage = true) val shouldShowActiveStages = activeStages.nonEmpty @@ -391,7 +391,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP var content = summary val appStartTime = store.applicationInfo().attempts.head.startTime.getTime() - content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, + content ++= makeTimeline((activeStages ++ completedStages 
++ failedStages).toSeq, store.executorList(false), appStartTime) val operationGraphContent = store.asOption(store.operationGraphForJob(jobId)) match { diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 842ee7aaf49bf..f8d9279c2404f 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -81,11 +81,11 @@ private[spark] class RDDOperationCluster( /** Return all the nodes which are cached. */ def getCachedNodes: Seq[RDDOperationNode] = { - _childNodes.filter(_.cached) ++ _childClusters.flatMap(_.getCachedNodes) + (_childNodes.filter(_.cached) ++ _childClusters.flatMap(_.getCachedNodes)).toSeq } def getBarrierClusters: Seq[RDDOperationCluster] = { - _childClusters.filter(_.barrier) ++ _childClusters.flatMap(_.getBarrierClusters) + (_childClusters.filter(_.barrier) ++ _childClusters.flatMap(_.getBarrierClusters)).toSeq } def canEqual(other: Any): Boolean = other.isInstanceOf[RDDOperationCluster] @@ -210,7 +210,7 @@ private[spark] object RDDOperationGraph extends Logging { } } - RDDOperationGraph(internalEdges, outgoingEdges, incomingEdges, rootCluster) + RDDOperationGraph(internalEdges.toSeq, outgoingEdges.toSeq, incomingEdges.toSeq, rootCluster) } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ced3f9d15720d..ceaddb4306579 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -328,11 +328,11 @@ private[spark] object JsonProtocol { ("Accumulables" -> accumulablesToJson(taskInfo.accumulables)) } - private lazy val accumulableBlacklist = Set("internal.metrics.updatedBlockStatuses") + private lazy val accumulableExcludeList = Set("internal.metrics.updatedBlockStatuses") def accumulablesToJson(accumulables: Iterable[AccumulableInfo]): JArray = { JArray(accumulables - .filterNot(_.name.exists(accumulableBlacklist.contains)) + .filterNot(_.name.exists(accumulableExcludeList.contains)) .toList.map(accumulableInfoToJson)) } @@ -1078,8 +1078,12 @@ private[spark] object JsonProtocol { val blockManagerAddress = blockManagerIdFromJson(json \ "Block Manager Address") val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Long] - val mapIndex = (json \ "Map Index") match { - case JNothing => 0 + val mapIndex = json \ "Map Index" match { + case JNothing => + // Note, we use the invalid value Int.MinValue here to fill the map index for backward + // compatibility. Otherwise, the fetch failed event will be dropped when the history + // server loads the event log written by the Spark version before 3.0. 
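The sentinel chosen above is read back with the usual json4s pattern: a missing field comes out of the \ lookup as JNothing, so pre-3.0 event logs that carry no "Map Index" fall back to the Int.MinValue default instead of failing extraction. A standalone sketch of that read path, assuming json4s-jackson on the classpath:

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

object MapIndexCompatSketch extends App {
  implicit val formats: Formats = DefaultFormats

  // "Map Index" only exists in event logs written by Spark 3.0+, so older logs
  // simply lack the field and json4s returns JNothing for the lookup.
  def readMapIndex(json: JValue): Int = json \ "Map Index" match {
    case JNothing => Int.MinValue          // sentinel meaning "index unknown"
    case other    => other.extract[Int]
  }

  val oldEvent = parse("""{"Shuffle ID": 1, "Map ID": 7}""")
  val newEvent = parse("""{"Shuffle ID": 1, "Map ID": 7, "Map Index": 3}""")
  println(readMapIndex(oldEvent))   // -2147483648
  println(readMapIndex(newEvent))   // 3
}
```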
+ Int.MinValue case x => x.extract[Int] } val reduceId = (json \ "Reduce ID").extract[Int] @@ -1210,7 +1214,8 @@ private[spark] object JsonProtocol { case Some(id) => id.extract[Int] case None => ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID } - new ExecutorInfo(executorHost, totalCores, logUrls, attributes, resources, resourceProfileId) + new ExecutorInfo(executorHost, totalCores, logUrls, attributes.toMap, resources.toMap, + resourceProfileId) } def blockUpdatedInfoFromJson(json: JValue): BlockUpdatedInfo = { diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 7272b375e5388..0e4debc595345 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.concurrent.duration._ + import org.apache.spark.SparkConf import org.apache.spark.internal.config import org.apache.spark.internal.config.Network._ @@ -54,6 +56,14 @@ private[spark] object RpcUtils { RpcTimeout(conf, Seq(RPC_LOOKUP_TIMEOUT.key, NETWORK_TIMEOUT.key), "120s") } + /** + * Infinite timeout is used internally, so there's no timeout configuration property that + * controls it. Therefore, we use "infinite" without any specific reason as its timeout + * configuration property. And its timeout property should never be accessed since infinite + * means we never timeout. + */ + val INFINITE_TIMEOUT = new RpcTimeout(Long.MaxValue.nanos, "infinite") + private val MAX_MESSAGE_SIZE_IN_MB = Int.MaxValue / 1024 / 1024 /** Returns the configured max message size for messages in bytes. */ diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 9636fe88c77c2..35d60bb514405 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1026,13 +1026,27 @@ private[spark] object Utils extends Logging { customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } + /** + * Checks if the host contains only valid hostname/ip without port + * NOTE: Incase of IPV6 ip it should be enclosed inside [] + */ def checkHost(host: String): Unit = { - assert(host != null && host.indexOf(':') == -1, s"Expected hostname (not IP) but got $host") + if (host != null && host.split(":").length > 2) { + assert(host.startsWith("[") && host.endsWith("]"), + s"Expected hostname or IPv6 IP enclosed in [] but got $host") + } else { + assert(host != null && host.indexOf(':') == -1, s"Expected hostname or IP but got $host") + } } def checkHostPort(hostPort: String): Unit = { - assert(hostPort != null && hostPort.indexOf(':') != -1, - s"Expected host and port but got $hostPort") + if (hostPort != null && hostPort.split(":").length > 2) { + assert(hostPort != null && hostPort.indexOf("]:") != -1, + s"Expected host and port but got $hostPort") + } else { + assert(hostPort != null && hostPort.indexOf(':') != -1, + s"Expected host and port but got $hostPort") + } } // Typically, this will be of order of number of nodes in cluster @@ -1046,18 +1060,30 @@ private[spark] object Utils extends Logging { return cached } - val indx: Int = hostPort.lastIndexOf(':') - // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... - // but then hadoop does not support ipv6 right now. 
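The checkHost, checkHostPort and parseHostPort changes above all follow the same convention: a value containing more than one colon is treated as an IPv6 literal, must be wrapped in brackets, and any port follows the "]:" separator. Below is a standalone sketch of that convention; it is simplified and skips the result caching and empty-port handling of the real Utils.parseHostPort.

```scala
object HostPortSketch extends App {
  // Splits "host:port" while tolerating bracketed IPv6 literals such as "[::1]:8080".
  // A string with more than two ':'-separated segments is treated as IPv6 and is
  // required to be bracketed, matching the checkHost/checkHostPort changes above.
  def parseHostPort(hostPort: String): (String, Int) = {
    val isIPv6 = hostPort.split(":").length > 2
    if (isIPv6) {
      require(hostPort.startsWith("["), s"IPv6 address must be enclosed in []: $hostPort")
      val idx = hostPort.lastIndexOf("]:")
      if (idx == -1) (hostPort, 0)                                    // bracketed IPv6, no port
      else (hostPort.substring(0, idx + 1), hostPort.substring(idx + 2).toInt)
    } else {
      val idx = hostPort.lastIndexOf(':')
      if (idx == -1) (hostPort, 0)                                    // plain host, no port
      else (hostPort.substring(0, idx), hostPort.substring(idx + 1).toInt)
    }
  }

  Seq("localhost", "localhost:7077", "[2001:db8::1]", "[2001:db8::1]:7077").foreach { hp =>
    println(s"$hp -> ${parseHostPort(hp)}")
  }
}
```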
- // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 - if (-1 == indx) { + def setDefaultPortValue: (String, Int) = { val retval = (hostPort, 0) hostPortParseResults.put(hostPort, retval) - return retval + retval + } + // checks if the hostport contains IPV6 ip and parses the host, port + if (hostPort != null && hostPort.split(":").length > 2) { + val indx: Int = hostPort.lastIndexOf("]:") + if (-1 == indx) { + return setDefaultPortValue + } + val port = hostPort.substring(indx + 2).trim() + val retval = (hostPort.substring(0, indx + 1).trim(), if (port.isEmpty) 0 else port.toInt) + hostPortParseResults.putIfAbsent(hostPort, retval) + } else { + val indx: Int = hostPort.lastIndexOf(':') + if (-1 == indx) { + return setDefaultPortValue + } + val port = hostPort.substring(indx + 1).trim() + val retval = (hostPort.substring(0, indx).trim(), if (port.isEmpty) 0 else port.toInt) + hostPortParseResults.putIfAbsent(hostPort, retval) } - val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) - hostPortParseResults.putIfAbsent(hostPort, retval) hostPortParseResults.get(hostPort) } @@ -1716,7 +1742,7 @@ private[spark] object Utils extends Logging { if (inWord || inDoubleQuote || inSingleQuote) { endWord() } - buf + buf.toSeq } /* Calculates 'x' modulo 'mod', takes to consideration sign of x, diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index cc97bbfa7201f..dc39170ecf382 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -659,7 +659,7 @@ private[spark] class ExternalSorter[K, V, C]( } } else { // Merge spilled and in-memory data - merge(spills, destructiveIterator( + merge(spills.toSeq, destructiveIterator( collection.partitionedDestructiveSortedIterator(comparator))) } } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__status___offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__status___offset___length_expectation.json new file mode 100644 index 0000000000000..28509e33c5dcc --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__status___offset___length_expectation.json @@ -0,0 +1,99 @@ +[ { + "taskId" : 1, + "index" : 1, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:06.502GMT", + "duration" : 421, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 31, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 350, + "executorCpuTime" : 0, + "resultSize" : 2010, + "jvmGcTime" : 7, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 60488, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 3934399, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + 
"gettingResultTime" : 0 +}, { + "taskId" : 2, + "index" : 2, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:06.503GMT", + "duration" : 419, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 32, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 348, + "executorCpuTime" : 0, + "resultSize" : 2010, + "jvmGcTime" : 7, + "resultSerializationTime" : 2, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 60488, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 89885, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__status___sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__status___sortBy_short_names__runtime_expectation.json new file mode 100644 index 0000000000000..01eef1b565bf6 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__status___sortBy_short_names__runtime_expectation.json @@ -0,0 +1,981 @@ +[ { + "taskId" : 40, + "index" : 40, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.197GMT", + "duration" : 24, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 14, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 94792, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 +}, { + "taskId" : 41, + "index" : 41, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.200GMT", + "duration" : 24, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 16, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + 
"localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 90765, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 +}, { + "taskId" : 43, + "index" : 43, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.204GMT", + "duration" : 39, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 16, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 171516, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 21, + "gettingResultTime" : 0 +}, { + "taskId" : 57, + "index" : 57, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.257GMT", + "duration" : 21, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 16, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 96849, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 +}, { + "taskId" : 58, + "index" : 58, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.263GMT", + "duration" : 23, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 16, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + 
"shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 97521, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 +}, { + "taskId" : 68, + "index" : 68, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.306GMT", + "duration" : 22, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 16, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 101750, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 +}, { + "taskId" : 86, + "index" : 86, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.374GMT", + "duration" : 28, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 16, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 95848, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 +}, { + "taskId" : 32, + "index" : 32, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.148GMT", + "duration" : 33, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 17, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 89603, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + 
"gettingResultTime" : 0 +}, { + "taskId" : 39, + "index" : 39, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.180GMT", + "duration" : 32, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 17, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 98748, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 +}, { + "taskId" : 42, + "index" : 42, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.203GMT", + "duration" : 42, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 10, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 17, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 103713, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 15, + "gettingResultTime" : 0 +}, { + "taskId" : 51, + "index" : 51, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.242GMT", + "duration" : 21, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 17, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 96013, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 +}, { + "taskId" : 59, + "index" : 59, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.265GMT", + "duration" : 23, + "executorId" : "driver", 
+ "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 17, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 100753, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 +}, { + "taskId" : 63, + "index" : 63, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.276GMT", + "duration" : 40, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 20, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 17, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 5, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 102779, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 +}, { + "taskId" : 87, + "index" : 87, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.374GMT", + "duration" : 36, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 12, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 17, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 102159, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 +}, { + "taskId" : 90, + "index" : 90, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.385GMT", + "duration" : 23, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + 
"executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 17, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 98472, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 +}, { + "taskId" : 99, + "index" : 99, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.426GMT", + "duration" : 22, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 17, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70565, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 133964, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 +}, { + "taskId" : 44, + "index" : 44, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.205GMT", + "duration" : 37, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 18, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 98293, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 +}, { + "taskId" : 47, + "index" : 47, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.212GMT", + "duration" : 33, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 18, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + 
"resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 103015, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 +}, { + "taskId" : 50, + "index" : 50, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.240GMT", + "duration" : 26, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 18, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 90836, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 +}, { + "taskId" : 52, + "index" : 52, + "attempt" : 0, + "launchTime" : "2015-05-06T13:03:07.243GMT", + "duration" : 28, + "executorId" : "driver", + "host" : "localhost", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 18, + "executorCpuTime" : 0, + "resultSize" : 2065, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 70564, + "recordsRead" : 10000 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 1710, + "writeTime" : 89664, + "recordsWritten" : 10 + } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__status_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__status_expectation.json new file mode 100644 index 0000000000000..9896aceb275de --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__status_expectation.json @@ -0,0 +1,531 @@ +[ { + "taskId" : 1, + "index" : 1, + "attempt" : 0, + "launchTime" : "2016-11-15T23:20:44.052GMT", + "duration" : 675, + "executorId" : "0", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, 
+ "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 494, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 30, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", + "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 181, + "gettingResultTime" : 0 +}, { + "taskId" : 3, + "index" : 3, + "attempt" : 0, + "launchTime" : "2016-11-15T23:20:44.053GMT", + "duration" : 725, + "executorId" : "2", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat 
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 456, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 32, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + }, + "schedulerDelay" : 269, + "gettingResultTime" : 0 +}, { + "taskId" : 5, + "index" : 5, + "attempt" : 0, + "launchTime" : "2016-11-15T23:20:44.055GMT", + "duration" : 665, + "executorId" : "0", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 495, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 30, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", + "stderr" : 
"http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 170, + "gettingResultTime" : 0 +}, { + "taskId" : 7, + "index" : 7, + "attempt" : 0, + "launchTime" : "2016-11-15T23:20:44.056GMT", + "duration" : 685, + "executorId" : "2", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 448, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 32, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + }, + "schedulerDelay" : 237, + "gettingResultTime" : 0 +}, { + "taskId" : 9, + "index" : 9, + "attempt" : 0, + "launchTime" : "2016-11-15T23:20:44.057GMT", + "duration" : 732, + "executorId" : "0", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat 
org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 503, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 30, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", + "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 229, + "gettingResultTime" : 0 +}, { + "taskId" : 11, + "index" : 11, + "attempt" : 0, + "launchTime" : "2016-11-15T23:20:44.058GMT", + "duration" : 678, + "executorId" : "2", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 451, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 32, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + 
"remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + }, + "schedulerDelay" : 227, + "gettingResultTime" : 0 +}, { + "taskId" : 13, + "index" : 13, + "attempt" : 0, + "launchTime" : "2016-11-15T23:20:44.060GMT", + "duration" : 669, + "executorId" : "0", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 494, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 30, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", + "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 175, + "gettingResultTime" : 0 +}, { + "taskId" : 15, + "index" : 15, + "attempt" : 0, + "launchTime" : "2016-11-15T23:20:44.065GMT", + "duration" : 672, + "executorId" : "2", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat 
$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 446, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 32, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + }, + "schedulerDelay" : 226, + "gettingResultTime" : 0 +}, { + "taskId" : 19, + "index" : 11, + "attempt" : 1, + "launchTime" : "2016-11-15T23:20:44.736GMT", + "duration" : 13, + "executorId" : "2", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + 
"executorDeserializeCpuTime" : 0, + "executorRunTime" : 2, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 +}, { + "taskId" : 20, + "index" : 15, + "attempt" : 1, + "launchTime" : "2016-11-15T23:20:44.737GMT", + "duration" : 19, + "executorId" : "2", + "host" : "172.22.0.111", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 10, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } + }, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + }, + "schedulerDelay" : 9, + "gettingResultTime" : 0 +} ] diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala 
b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index a75cf3f0381df..d701cb65460af 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -157,7 +157,7 @@ private class SaveInfoListener extends SparkListener { def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] = - completedTaskInfos.getOrElse((stageId, stageAttemptId), Seq.empty[TaskInfo]) + completedTaskInfos.getOrElse((stageId, stageAttemptId), Seq.empty[TaskInfo]).toSeq /** * If `jobCompletionCallback` is set, block until the next call has finished. diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index a69381d18e3b6..21090e98ea285 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -194,7 +194,7 @@ trait RDDCheckpointTester { self: SparkFunSuite => /** * Serialize and deserialize an object. This is useful to verify the objects * contents after deserialization (e.g., the contents of an RDD split after - * it is sent to a slave along with a task) + * it is sent to an executor along with a task) */ protected def serializeDeserialize[T](obj: T): T = { val bytes = Utils.serialize(obj) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 92ed24408384f..81530a8fda84d 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -291,14 +291,14 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { val shuffleIds = 0 until sc.newShuffleId val broadcastIds = broadcastBuffer.map(_.id) - val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds.toSeq) runGC() intercept[Exception] { preGCTester.assertCleanup()(timeout(1.second)) } // Test that GC triggers the cleanup of all variables after the dereferencing them - val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds.toSeq) broadcastBuffer.clear() rddBuffer.clear() runGC() @@ -309,7 +309,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { assert(sc.env.blockManager.master.getMatchingBlockIds({ case BroadcastBlockId(`taskClosureBroadcastId`, _) => true case _ => false - }, askSlaves = true).isEmpty) + }, askStorageEndpoints = true).isEmpty) } test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { @@ -331,14 +331,14 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { val shuffleIds = 0 until sc.newShuffleId val broadcastIds = broadcastBuffer.map(_.id) - val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds.toSeq) runGC() intercept[Exception] { preGCTester.assertCleanup()(timeout(1.second)) } // Test that GC triggers the cleanup of all variables after the dereferencing them - val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, 
broadcastIds.toSeq) broadcastBuffer.clear() rddBuffer.clear() runGC() @@ -349,7 +349,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { assert(sc.env.blockManager.master.getMatchingBlockIds({ case BroadcastBlockId(`taskClosureBroadcastId`, _) => true case _ => false - }, askSlaves = true).isEmpty) + }, askStorageEndpoints = true).isEmpty) } } @@ -528,7 +528,7 @@ class CleanerTester( blockManager.master.getMatchingBlockIds( _ match { case RDDBlockId(`rddId`, _) => true case _ => false - }, askSlaves = true) + }, askStorageEndpoints = true) } private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = { @@ -536,14 +536,14 @@ class CleanerTester( case ShuffleBlockId(`shuffleId`, _, _) => true case ShuffleIndexBlockId(`shuffleId`, _, _) => true case _ => false - }, askSlaves = true) + }, askStorageEndpoints = true) } private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { case BroadcastBlockId(`broadcastId`, _) => true case _ => false - }, askSlaves = true) + }, askStorageEndpoints = true) } private def blockManager = sc.env.blockManager diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 4d157b9607000..27862806c0840 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -45,11 +45,11 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex // this test will hang. Correct behavior is that executors don't crash but fail tasks // and the scheduler throws a SparkException. - // numSlaves must be less than numPartitions - val numSlaves = 3 + // numWorkers must be less than numPartitions + val numWorkers = 3 val numPartitions = 10 - sc = new SparkContext("local-cluster[%s,1,1024]".format(numSlaves), "test") + sc = new SparkContext("local-cluster[%s,1,1024]".format(numWorkers), "test") val data = sc.parallelize(1 to 100, numPartitions). 
map(x => throw new NotSerializableExn(new NotSerializableClass)) intercept[SparkException] { @@ -69,10 +69,10 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex ) masterStrings.foreach { - case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - assert(numSlaves.toInt == 2) - assert(coresPerSlave.toInt == 1) - assert(memoryPerSlave.toInt == 1024) + case LOCAL_CLUSTER_REGEX(numWorkers, coresPerWorker, memoryPerWorker) => + assert(numWorkers.toInt == 2) + assert(coresPerWorker.toInt == 1) + assert(memoryPerWorker.toInt == 1024) } } @@ -227,7 +227,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(data.count() === size) assert(data.count() === size) // ensure only a subset of partitions were cached - val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true) + val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, + askStorageEndpoints = true) assert(rddBlocks.size === 0, s"expected no RDD blocks, found ${rddBlocks.size}") } @@ -244,7 +245,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(data.count() === size) assert(data.count() === size) // ensure only a subset of partitions were cached - val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true) + val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, + askStorageEndpoints = true) assert(rddBlocks.size > 0, "no RDD blocks found") assert(rddBlocks.size < numPartitions, s"too many RDD blocks found, expected <$numPartitions") } diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index c217419f4092e..65391db405a55 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -72,12 +72,12 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.blockStoreClient.getClass should equal(classOf[ExternalBlockStoreClient]) - // In a slow machine, one slave may register hundreds of milliseconds ahead of the other one. - // If we don't wait for all slaves, it's possible that only one executor runs all jobs. Then + // In a slow machine, one executor may register hundreds of milliseconds ahead of the other one. + // If we don't wait for all executors, it's possible that only one executor runs all jobs. Then // all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will directly fetch // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. - // Therefore, we should wait until all slaves are up + // Therefore, we should wait until all executors are up TestUtils.waitUntilExecutorsUp(sc, 2, 60000) val rdd = sc.parallelize(0 until 1000, 10) @@ -109,12 +109,12 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi sc.env.blockManager.hostLocalDirManager.isDefined should equal(true) sc.env.blockManager.blockStoreClient.getClass should equal(classOf[ExternalBlockStoreClient]) - // In a slow machine, one slave may register hundreds of milliseconds ahead of the other one. 
- // If we don't wait for all slaves, it's possible that only one executor runs all jobs. Then + // In a slow machine, one executor may register hundreds of milliseconds ahead of the other one. + // If we don't wait for all executors, it's possible that only one executor runs all jobs. Then // all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will directly fetch // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. - // Therefore, we should wait until all slaves are up + // Therefore, we should wait until all executors are up TestUtils.waitUntilExecutorsUp(sc, 2, 60000) val rdd = sc.parallelize(0 until 1000, 10) diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 312691302b064..a2e70b23a3e5d 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -261,7 +261,7 @@ class HeartbeatReceiverSuite // We may receive undesired SparkListenerExecutorAdded from LocalSchedulerBackend, // so exclude it from the map. See SPARK-10800. heartbeatReceiver.invokePrivate(_executorLastSeen()). - filterKeys(_ != SparkContext.DRIVER_IDENTIFIER) + filterKeys(_ != SparkContext.DRIVER_IDENTIFIER).toMap } } @@ -287,6 +287,8 @@ private class FakeSchedulerBackend( resourceProfileManager: ResourceProfileManager) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + def this() = this(null, null, null, null) + protected override def doRequestTotalExecutors( resourceProfileToTotalExecs: Map[ResourceProfile, Int]): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 5399d868f46f1..f2b81e5153ae4 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -220,7 +220,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { super.registerAccumulatorForCleanup(a) } - def accumsRegisteredForCleanup: Seq[Long] = accumsRegistered.toArray + def accumsRegisteredForCleanup: Seq[Long] = accumsRegistered.toSeq } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 630ffd9baa06e..b5b68f639ffc9 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -136,21 +136,21 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + val mapWorkerRpcEnv = createRpcEnv("spark-worker", hostname, 0, new SecurityManager(conf)) + val mapWorkerTracker = new MapOutputTrackerWorker(conf) + mapWorkerTracker.trackerEndpoint = + mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) - 
slaveTracker.updateEpoch(masterTracker.getEpoch) + mapWorkerTracker.updateEpoch(masterTracker.getEpoch) // This is expected to fail because no outputs have been registered for the shuffle. - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L), 5)) - slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === + mapWorkerTracker.updateEpoch(masterTracker.getEpoch) + assert(mapWorkerTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -158,17 +158,17 @@ class MapOutputTrackerSuite extends SparkFunSuite { val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) - slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + mapWorkerTracker.updateEpoch(masterTracker.getEpoch) + intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) } // failure should be cached - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) } assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.stop() - slaveTracker.stop() + mapWorkerTracker.stop() rpcEnv.shutdown() - slaveRpcEnv.shutdown() + mapWorkerRpcEnv.shutdown() } test("remote fetch below max RPC message size") { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 9e39271bdf9ee..3d6690cb85348 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -182,7 +182,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)) - .map(p => (p._1, p._2.map(_.toArray))) + .map(p => (p._1, p._2.map(_.toSeq))) .collectAsMap() assert(results(1)(0).length === 3) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 30237fd576830..d111bb33ce8ff 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -934,6 +934,18 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } } + + test("SPARK-32160: Disallow to create SparkContext in executors") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]")) + + val error = intercept[SparkException] { + sc.range(0, 1).foreach { _ => + new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + } + }.getMessage() + + assert(error.contains("SparkContext should only be created and accessed on the driver.")) + } } object 
SparkContextSuite { diff --git a/core/src/test/scala/org/apache/spark/ThreadAudit.scala b/core/src/test/scala/org/apache/spark/ThreadAudit.scala index 44d1f220bf6b1..1e2917621fa79 100644 --- a/core/src/test/scala/org/apache/spark/ThreadAudit.scala +++ b/core/src/test/scala/org/apache/spark/ThreadAudit.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging */ trait ThreadAudit extends Logging { - val threadWhiteList = Set( + val threadExcludeList = Set( /** * Netty related internal threads. * These are excluded because their lifecycle is handled by the netty itself @@ -108,7 +108,7 @@ trait ThreadAudit extends Logging { if (threadNamesSnapshot.nonEmpty) { val remainingThreadNames = runningThreadNames().diff(threadNamesSnapshot) - .filterNot { s => threadWhiteList.exists(s.matches(_)) } + .filterNot { s => threadExcludeList.exists(s.matches(_)) } if (remainingThreadNames.nonEmpty) { logWarning(s"\n\n===== POSSIBLE THREAD LEAK IN SUITE $shortSuiteName, " + s"thread names: ${remainingThreadNames.mkString(", ")} =====\n") diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index a6776ee077894..5e8b25f425166 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -68,14 +68,14 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio } encryptionTest("Accessing TorrentBroadcast variables in a local cluster") { conf => - val numSlaves = 4 + val numWorkers = 4 conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer") conf.set(config.BROADCAST_COMPRESS, true) - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numWorkers), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + val results = sc.parallelize(1 to numWorkers).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numWorkers).map(x => (x, 10)).toSet) } test("TorrentBroadcast's blockifyObject and unblockifyObject are inverses") { @@ -99,12 +99,12 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio } test("Test Lazy Broadcast variables with TorrentBroadcast") { - val numSlaves = 2 - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") - val rdd = sc.parallelize(1 to numSlaves) + val numWorkers = 2 + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numWorkers), "test") + val rdd = sc.parallelize(1 to numWorkers) val results = new DummyBroadcastClass(rdd).doSomething() - assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet) + assert(results.toSet === (1 to numWorkers).map(x => (x, false)).toSet) } test("Unpersisting TorrentBroadcast on executors only in local mode") { @@ -196,27 +196,27 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio */ private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean): Unit = { - val numSlaves = if (distributed) 2 else 0 + val numWorkers = if (distributed) 2 else 0 // Verify that blocks are persisted only on the driver def afterCreation(broadcastId: Long, bmm: BlockManagerMaster): Unit = { var blockId = 
BroadcastBlockId(broadcastId) - var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + var statuses = bmm.getBlockStatus(blockId, askStorageEndpoints = true) assert(statuses.size === 1) blockId = BroadcastBlockId(broadcastId, "piece0") - statuses = bmm.getBlockStatus(blockId, askSlaves = true) + statuses = bmm.getBlockStatus(blockId, askStorageEndpoints = true) assert(statuses.size === 1) } // Verify that blocks are persisted in both the executors and the driver def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster): Unit = { var blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === numSlaves + 1) + val statuses = bmm.getBlockStatus(blockId, askStorageEndpoints = true) + assert(statuses.size === numWorkers + 1) blockId = BroadcastBlockId(broadcastId, "piece0") - assert(statuses.size === numSlaves + 1) + assert(statuses.size === numWorkers + 1) } // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver @@ -224,16 +224,16 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster): Unit = { var blockId = BroadcastBlockId(broadcastId) var expectedNumBlocks = if (removeFromDriver) 0 else 1 - var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + var statuses = bmm.getBlockStatus(blockId, askStorageEndpoints = true) assert(statuses.size === expectedNumBlocks) blockId = BroadcastBlockId(broadcastId, "piece0") expectedNumBlocks = if (removeFromDriver) 0 else 1 - statuses = bmm.getBlockStatus(blockId, askSlaves = true) + statuses = bmm.getBlockStatus(blockId, askStorageEndpoints = true) assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, afterCreation, + testUnpersistBroadcast(distributed, numWorkers, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -248,7 +248,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio */ private def testUnpersistBroadcast( distributed: Boolean, - numSlaves: Int, // used only when distributed = true + numWorkers: Int, // used only when distributed = true afterCreation: (Long, BlockManagerMaster) => Unit, afterUsingBroadcast: (Long, BlockManagerMaster) => Unit, afterUnpersist: (Long, BlockManagerMaster) => Unit, @@ -256,10 +256,10 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio sc = if (distributed) { val _sc = - new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") + new SparkContext("local-cluster[%d, 1, 1024]".format(numWorkers), "test") // Wait until all salves are up try { - TestUtils.waitUntilExecutorsUp(_sc, numSlaves, 60000) + TestUtils.waitUntilExecutorsUp(_sc, numWorkers, 60000) _sc } catch { case e: Throwable => @@ -278,7 +278,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio // Use broadcast variable on all executors val partitions = 10 - assert(partitions > numSlaves) + assert(partitions > numWorkers) val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) afterUsingBroadcast(broadcast.id, blockManagerMaster) diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index 42b8cde650390..b986be03e965c 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -317,7 +317,7 @@ private[deploy] object IvyTestUtils { val rFiles = createRFiles(root, className, artifact.groupId) allFiles.append(rFiles: _*) } - val jarFile = packJar(jarPath, artifact, allFiles, useIvyLayout, withR) + val jarFile = packJar(jarPath, artifact, allFiles.toSeq, useIvyLayout, withR) assert(jarFile.exists(), "Problem creating Jar file") val descriptor = createDescriptor(tempPath, artifact, dependencies, useIvyLayout) assert(descriptor.exists(), "Problem creating Pom file") diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index eeccf56cbf02e..354e6eb2138d9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -106,6 +106,9 @@ object JsonConstants { """ |{"id":"id","starttime":3,"name":"name", |"cores":0,"user":"%s", + |"memoryperexecutor":1234, + |"resourcesperexecutor":[{"name":"gpu", + |"amount":3},{"name":"fpga","amount":3}], |"memoryperslave":1234, |"resourcesperslave":[{"name":"gpu", |"amount":3},{"name":"fpga","amount":3}], @@ -132,7 +135,8 @@ object JsonConstants { val appDescJsonStr = """ - |{"name":"name","cores":4,"memoryperslave":1234,"resourcesperslave":[], + |{"name":"name","cores":4,"memoryperexecutor":1234,"resourcesperexecutor":[], + |"memoryperslave":1234,"resourcesperslave":[], |"user":"%s","command":"Command(mainClass,List(arg1, arg2),Map(),List(),List(),List())"} """.format(System.getProperty("user.name", "")).stripMargin diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index fd2d1f56ed9b6..fd3d4bcf62f69 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -1210,17 +1210,17 @@ class SparkSubmitSuite testRemoteResources(enableHttpFs = true) } - test("force download from blacklisted schemes") { - testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("http")) + test("force download from forced schemes") { + testRemoteResources(enableHttpFs = true, forceDownloadSchemes = Seq("http")) } test("force download for all the schemes") { - testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("*")) + testRemoteResources(enableHttpFs = true, forceDownloadSchemes = Seq("*")) } private def testRemoteResources( enableHttpFs: Boolean, - blacklistSchemes: Seq[String] = Nil): Unit = { + forceDownloadSchemes: Seq[String] = Nil): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) if (enableHttpFs) { @@ -1237,8 +1237,8 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" - val forceDownloadArgs = if (blacklistSchemes.nonEmpty) { - Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${blacklistSchemes.mkString(",")}") + val forceDownloadArgs = if (forceDownloadSchemes.nonEmpty) { + Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${forceDownloadSchemes.mkString(",")}") } else { Nil } @@ -1256,19 +1256,19 @@ class SparkSubmitSuite val jars = conf.get("spark.yarn.dist.jars").split(",").toSet - def isSchemeBlacklisted(scheme: String) = { - 
blacklistSchemes.contains("*") || blacklistSchemes.contains(scheme) + def isSchemeForcedDownload(scheme: String) = { + forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme) } - if (!isSchemeBlacklisted("s3")) { + if (!isSchemeForcedDownload("s3")) { assert(jars.contains(tmpS3JarPath)) } - if (enableHttpFs && blacklistSchemes.isEmpty) { + if (enableHttpFs && forceDownloadSchemes.isEmpty) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) - } else if (!enableHttpFs || isSchemeBlacklisted("http")) { + } else if (!enableHttpFs || isSchemeForcedDownload("http")) { // If Http FS is not supported by yarn service, or http scheme is configured to be force // downloading, the URI of remote http resource should be changed to a local one. val jarName = new File(tmpHttpJar.toURI).getName diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 57cbda3c0620d..c7c3ad27675fa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -545,7 +545,7 @@ class StandaloneDynamicAllocationSuite // will not timeout anything related to executors. .set(config.Network.NETWORK_TIMEOUT.key, "2h") .set(config.EXECUTOR_HEARTBEAT_INTERVAL.key, "1h") - .set(config.STORAGE_BLOCKMANAGER_SLAVE_TIMEOUT.key, "1h") + .set(config.STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT.key, "1h") } /** Make a master to which our application will send executor requests. */ diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index c2f34fc3a95ed..ade03a0095c19 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -1117,7 +1117,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with Matchers with Logging { } } - test("SPARK-24948: blacklist files we don't have read permission on") { + test("SPARK-24948: ignore files we don't have read permission on") { val clock = new ManualClock(1533132471) val provider = new FsHistoryProvider(createTestConf(), clock) val accessDenied = newLogFile("accessDenied", None, inProgress = false) @@ -1137,17 +1137,17 @@ class FsHistoryProviderSuite extends SparkFunSuite with Matchers with Logging { updateAndCheck(mockedProvider) { list => list.size should be(1) } - // Doing 2 times in order to check the blacklist filter too + // Doing 2 times in order to check the inaccessibleList filter too updateAndCheck(mockedProvider) { list => list.size should be(1) } val accessDeniedPath = new Path(accessDenied.getPath) - assert(mockedProvider.isBlacklisted(accessDeniedPath)) + assert(!mockedProvider.isAccessible(accessDeniedPath)) clock.advance(24 * 60 * 60 * 1000 + 1) // add a bit more than 1d isReadable = true mockedProvider.cleanLogs() updateAndCheck(mockedProvider) { list => - assert(!mockedProvider.isBlacklisted(accessDeniedPath)) + assert(mockedProvider.isAccessible(accessDeniedPath)) assert(list.exists(_.name == "accessDenied")) assert(list.exists(_.name == "accessGranted")) list.size should be(2) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala index f78469e132490..9004e86323691 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala @@ -158,4 +158,56 @@ class HistoryServerDiskManagerSuite extends SparkFunSuite with BeforeAndAfter { assert(manager.approximateSize(50L, true) > 50L) } + test("SPARK-32024: update ApplicationStoreInfo.size during initializing") { + val manager = mockManager() + val leaseA = manager.lease(2) + doReturn(3L).when(manager).sizeOf(meq(leaseA.tmpPath)) + val dstPathA = manager.appStorePath("app1", None) + doReturn(3L).when(manager).sizeOf(meq(dstPathA)) + val dstA = leaseA.commit("app1", None) + assert(manager.free() === 0) + assert(manager.committed() === 3) + // Listing store tracks dstA now. + assert(store.read(classOf[ApplicationStoreInfo], dstA.getAbsolutePath).size === 3) + + // Simulate: service restarts, new disk manager (manager1) is initialized. + val manager1 = mockManager() + // Simulate: event KVstore compaction before restart, directory size reduces. + doReturn(2L).when(manager1).sizeOf(meq(dstA)) + doReturn(2L).when(manager1).sizeOf(meq(new File(testDir, "apps"))) + manager1.initialize() + // "ApplicationStoreInfo.size" is updated for dstA. + assert(store.read(classOf[ApplicationStoreInfo], dstA.getAbsolutePath).size === 2) + assert(manager1.free() === 1) + // If "ApplicationStoreInfo.size" is not correctly updated, "IllegalStateException" + // would be thrown. + val leaseB = manager1.lease(2) + assert(manager1.free() === 1) + doReturn(2L).when(manager1).sizeOf(meq(leaseB.tmpPath)) + val dstPathB = manager.appStorePath("app2", None) + doReturn(2L).when(manager1).sizeOf(meq(dstPathB)) + val dstB = leaseB.commit("app2", None) + assert(manager1.committed() === 2) + // Listing store tracks dstB only, dstA is evicted by "makeRoom()". + assert(store.read(classOf[ApplicationStoreInfo], dstB.getAbsolutePath).size === 2) + + val manager2 = mockManager() + // Simulate: cache entities are written after replaying, directory size increases. + doReturn(3L).when(manager2).sizeOf(meq(dstB)) + doReturn(3L).when(manager2).sizeOf(meq(new File(testDir, "apps"))) + manager2.initialize() + // "ApplicationStoreInfo.size" is updated for dstB. + assert(store.read(classOf[ApplicationStoreInfo], dstB.getAbsolutePath).size === 3) + assert(manager2.free() === 0) + val leaseC = manager2.lease(2) + doReturn(2L).when(manager2).sizeOf(meq(leaseC.tmpPath)) + val dstPathC = manager.appStorePath("app3", None) + doReturn(2L).when(manager2).sizeOf(meq(dstPathC)) + val dstC = leaseC.commit("app3", None) + assert(manager2.free() === 1) + assert(manager2.committed() === 2) + // Listing store tracks dstC only, dstB is evicted by "makeRoom()". 
+ assert(store.read(classOf[ApplicationStoreInfo], dstC.getAbsolutePath).size === 2) + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 8737cd5bb3241..39b339caea385 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -154,6 +154,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "applications/local-1430917381534/stages/0/0/taskList?sortBy=-runtime", "stage task list w/ sortBy short names: runtime" -> "applications/local-1430917381534/stages/0/0/taskList?sortBy=runtime", + "stage task list w/ status" -> + "applications/app-20161115172038-0000/stages/0/0/taskList?status=failed", + "stage task list w/ status & offset & length" -> + "applications/local-1430917381534/stages/0/0/taskList?status=success&offset=1&length=2", + "stage task list w/ status & sortBy short names: runtime" -> + "applications/local-1430917381534/stages/0/0/taskList?status=success&sortBy=runtime", "stage list with accumulable json" -> "applications/local-1426533911241/1/stages", "stage with accumulable json" -> "applications/local-1426533911241/1/stages/0/0", @@ -313,8 +319,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (directSiteRelativeLinks) should not startWith (knoxBaseUrl) } - // TODO (SPARK-31723): re-enable it - ignore("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) val request = mock[HttpServletRequest] @@ -644,6 +649,19 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val actualContentType = conn.getContentType assert(actualContentType === expectedContentType) } + + test("Redirect to the root page when accessed to /history/") { + val port = server.boundPort + val url = new URL(s"http://localhost:$port/history/") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("GET") + conn.setUseCaches(false) + conn.setDefaultUseCaches(false) + conn.setInstanceFollowRedirects(false) + conn.connect() + assert(conn.getResponseCode === 302) + assert(conn.getHeaderField("Location") === s"http://localhost:$port/") + } } object HistoryServerSuite { diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 0cf573c2490b3..91128af82b022 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -685,7 +685,8 @@ class MasterSuite extends SparkFunSuite } } - test("SPARK-27510: Master should avoid dead loop while launching executor failed in Worker") { + // TODO(SPARK-32250): Enable the test back. It is flaky in GitHub Actions. 
+ ignore("SPARK-27510: Master should avoid dead loop while launching executor failed in Worker") { val master = makeAliveMaster() var worker: MockExecutorLaunchFailWorker = null try { diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 10f4bbcf7f48b..879107350bb52 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -140,7 +140,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { assert(slices(i).isInstanceOf[Range]) val range = slices(i).asInstanceOf[Range] assert(range.start === i * (N / 40), "slice " + i + " start") - assert(range.end === (i + 1) * (N / 40), "slice " + i + " end") + assert(range.last === (i + 1) * (N / 40) - 1, "slice " + i + " end") assert(range.step === 1, "slice " + i + " step") } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 18154d861a731..79f9c1396c87b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -656,7 +656,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } test("top with predefined ordering") { - val nums = Array.range(1, 100000) + val nums = Seq.range(1, 100000) val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) val topK = ints.top(5) assert(topK.size === 5) @@ -1098,7 +1098,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { override def getPartitions: Array[Partition] = Array(new Partition { override def index: Int = 0 }) - override def getDependencies: Seq[Dependency[_]] = mutableDependencies + override def getDependencies: Seq[Dependency[_]] = mutableDependencies.toSeq def addDependency(dep: Dependency[_]): Unit = { mutableDependencies += dep } @@ -1298,19 +1298,15 @@ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Seria val splitSize = fileSplit.getLength if (currentSum + splitSize < maxSize) { addPartition(partition, splitSize) - index += 1 - if (index == partitions.size) { - updateGroups - } } else { - if (currentGroup.partitions.size == 0) { - addPartition(partition, splitSize) - index += 1 - } else { - updateGroups + if (currentGroup.partitions.nonEmpty) { + updateGroups() } + addPartition(partition, splitSize) } + index += 1 } + updateGroups() groups.toArray } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 51d20d3428915..7013832757e38 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -474,7 +474,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } - test("All shuffle files on the slave should be cleaned up when slave lost") { + test("All shuffle files on the storage endpoint should be cleaned up when it is lost") { // reset the test context with the right shuffle service config afterEach() val conf = new SparkConf() @@ -779,9 +779,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } private val shuffleFileLossTests = Seq( - ("slave lost with shuffle service", SlaveLost("", false), true, 
false), - ("worker lost with shuffle service", SlaveLost("", true), true, true), - ("worker lost without shuffle service", SlaveLost("", true), false, true), + ("executor process lost with shuffle service", ExecutorProcessLost("", false), true, false), + ("worker lost with shuffle service", ExecutorProcessLost("", true), true, true), + ("worker lost without shuffle service", ExecutorProcessLost("", true), false, true), ("executor failure with shuffle service", ExecutorKilled, true, false), ("executor failure without shuffle service", ExecutorKilled, false, true)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 7c23e4449f461..915035e9eb71c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -325,7 +325,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit 8000L, 5000L, 7000L, 4000L, 6000L, 3000L, 10L, 90L, 2L, 20L) def max(a: Array[Long], b: Array[Long]): Array[Long] = - (a, b).zipped.map(Math.max) + (a, b).zipped.map(Math.max).toArray // calculated metric peaks per stage per executor // metrics sent during stage 0 for each executor diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala index 388d4e25a06cf..e392ff53e02c9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala @@ -26,7 +26,7 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { test("Track Executor Resource information") { // Init Executor Resource. - val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3"), 1) + val info = new ExecutorResourceInfo(GPU, Seq("0", "1", "2", "3"), 1) assert(info.availableAddrs.sorted sameElements Seq("0", "1", "2", "3")) assert(info.assignedAddrs.isEmpty) @@ -43,7 +43,7 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { test("Don't allow acquire address that is not available") { // Init Executor Resource. - val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3"), 1) + val info = new ExecutorResourceInfo(GPU, Seq("0", "1", "2", "3"), 1) // Acquire some addresses. info.acquire(Seq("0", "1")) assert(!info.availableAddrs.contains("1")) @@ -56,7 +56,7 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { test("Don't allow acquire address that doesn't exist") { // Init Executor Resource. - val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3"), 1) + val info = new ExecutorResourceInfo(GPU, Seq("0", "1", "2", "3"), 1) assert(!info.availableAddrs.contains("4")) // Acquire an address that doesn't exist val e = intercept[SparkException] { @@ -67,7 +67,7 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { test("Don't allow release address that is not assigned") { // Init Executor Resource. - val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3"), 1) + val info = new ExecutorResourceInfo(GPU, Seq("0", "1", "2", "3"), 1) // Acquire addresses info.acquire(Array("0", "1")) assert(!info.assignedAddrs.contains("2")) @@ -80,7 +80,7 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { test("Don't allow release address that doesn't exist") { // Init Executor Resource. 
- val info = new ExecutorResourceInfo(GPU, ArrayBuffer("0", "1", "2", "3"), 1) + val info = new ExecutorResourceInfo(GPU, Seq("0", "1", "2", "3"), 1) assert(!info.assignedAddrs.contains("4")) // Release an address that doesn't exist val e = intercept[SparkException] { @@ -93,7 +93,7 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { val slotSeq = Seq(10, 9, 8, 7, 6, 5, 4, 3, 2, 1) val addresses = ArrayBuffer("0", "1", "2", "3") slotSeq.foreach { slots => - val info = new ExecutorResourceInfo(GPU, addresses, slots) + val info = new ExecutorResourceInfo(GPU, addresses.toSeq, slots) for (_ <- 0 until slots) { addresses.foreach(addr => info.acquire(Seq(addr))) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index d4e8d63b54e5f..270b2c606ad0c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -621,7 +621,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } override def onStageCompleted(stage: SparkListenerStageCompleted): Unit = { - stageInfos(stage.stageInfo) = taskInfoMetrics + stageInfos(stage.stageInfo) = taskInfoMetrics.toSeq taskInfoMetrics = mutable.Buffer.empty } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 2efe6da5e986f..ea44a2d948ca9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -103,7 +103,7 @@ private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl) // DirectTaskResults that we receive from the executors private val _taskResults = new ArrayBuffer[DirectTaskResult[_]] - def taskResults: Seq[DirectTaskResult[_]] = _taskResults + def taskResults: Seq[DirectTaskResult[_]] = _taskResults.toSeq override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = { // work on a copy since the super class still needs to use the buffer diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index a75bae56229b4..e43be60e956be 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -641,7 +641,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(0 === taskDescriptions2.length) // provide the actual loss reason for executor0 - taskScheduler.executorLost("executor0", SlaveLost("oops")) + taskScheduler.executorLost("executor0", ExecutorProcessLost("oops")) // executor0's tasks should have failed now that the loss reason is known, so offering more // resources should make them be scheduled on the new executor. @@ -1141,7 +1141,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // Now we fail our second executor. The other task can still run on executor1, so make an offer // on that executor, and make sure that the other task (not the failed one) is assigned there. 
- taskScheduler.executorLost("executor1", SlaveLost("oops")) + taskScheduler.executorLost("executor1", ExecutorProcessLost("oops")) val nextTaskAttempts = taskScheduler.resourceOffers(IndexedSeq(new WorkerOffer("executor0", "host0", 1))).flatten // Note: Its OK if some future change makes this already realize the taskset has become @@ -1273,7 +1273,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(1 === taskDescriptions.length) // mark executor0 as dead - taskScheduler.executorLost("executor0", SlaveLost()) + taskScheduler.executorLost("executor0", ExecutorProcessLost()) assert(!taskScheduler.isExecutorAlive("executor0")) assert(!taskScheduler.hasExecutorsAliveOnHost("host0")) assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index e4aad58d25064..95c8197abbf0b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -415,7 +415,7 @@ class TaskSetManagerSuite // Now mark host2 as dead sched.removeExecutor("exec2") - manager.executorLost("exec2", "host2", SlaveLost()) + manager.executorLost("exec2", "host2", ExecutorProcessLost()) // nothing should be chosen assert(manager.resourceOffer("exec1", "host1", ANY)._1 === None) @@ -598,10 +598,10 @@ class TaskSetManagerSuite Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) // test if the valid locality is recomputed when the executor is lost sched.removeExecutor("execC") - manager.executorLost("execC", "host2", SlaveLost()) + manager.executorLost("execC", "host2", ExecutorProcessLost()) assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY))) sched.removeExecutor("execD") - manager.executorLost("execD", "host1", SlaveLost()) + manager.executorLost("execD", "host1", ExecutorProcessLost()) assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) } @@ -814,7 +814,7 @@ class TaskSetManagerSuite assert(resubmittedTasks === 0) assert(manager.runningTasks === 1) - manager.executorLost("execB", "host2", new SlaveLost()) + manager.executorLost("execB", "host2", new ExecutorProcessLost()) assert(manager.runningTasks === 0) assert(resubmittedTasks === 0) } @@ -923,7 +923,7 @@ class TaskSetManagerSuite // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called assert(killTaskCalled) // Host 3 Losts, there's only task 2.0 on it, which killed by task 2.1 - manager.executorLost("exec3", "host3", SlaveLost()) + manager.executorLost("exec3", "host3", ExecutorProcessLost()) // Check the resubmittedTasks assert(resubmittedTasks === 0) } @@ -1044,8 +1044,8 @@ class TaskSetManagerSuite assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) sched.removeExecutor("execA") sched.removeExecutor("execB.2") - manager.executorLost("execA", "host1", SlaveLost()) - manager.executorLost("execB.2", "host2", SlaveLost()) + manager.executorLost("execA", "host1", ExecutorProcessLost()) + manager.executorLost("execB.2", "host2", ExecutorProcessLost()) clock.advance(LOCALITY_WAIT_MS * 4) sched.addExecutor("execC", "host3") manager.executorAdded() @@ -1569,7 +1569,7 @@ class TaskSetManagerSuite assert(resubmittedTasks.isEmpty) // Host 2 Losts, meaning we lost the map output task4 - manager.executorLost("exec2", "host2", SlaveLost()) + manager.executorLost("exec2", "host2", 
ExecutorProcessLost()) // Make sure that task with index 2 is re-submitted assert(resubmittedTasks.contains(2)) @@ -1670,7 +1670,7 @@ class TaskSetManagerSuite for (i <- 0 to 99) { locations += Seq(TaskLocation("host" + i)) } - val taskSet = FakeTask.createTaskSet(100, locations: _*) + val taskSet = FakeTask.createTaskSet(100, locations.toSeq: _*) val clock = new ManualClock // make sure we only do one rack resolution call, for the entire batch of hosts, as this // can be expensive. The FakeTaskScheduler calls rack resolution more than the real one diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala index 01e3d6a46e709..3f5ffaa732f25 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala @@ -33,7 +33,7 @@ class BlockManagerInfoSuite extends SparkFunSuite { timeMs = 300, maxOnHeapMem = 10000, maxOffHeapMem = 20000, - slaveEndpoint = null, + storageEndpoint = null, if (svcEnabled) Some(new JHashMap[BlockId, BlockStatus]) else None) test(s"$testName externalShuffleServiceEnabled=$svcEnabled") { f(svcEnabled, bmInfo) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bfef8f1ab29d8..dc1c7cd52d466 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Future +import scala.concurrent.{Future, TimeoutException} import scala.concurrent.duration._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -49,8 +49,9 @@ import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransport import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExecutorDiskUtils, ExternalBlockStoreClient} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} -import org.apache.spark.rpc.RpcEnv +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerBlockUpdated} +import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend} import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager @@ -93,6 +94,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE .set(MEMORY_STORAGE_FRACTION, 0.999) .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") .set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L) + .set(Network.RPC_ASK_TIMEOUT, "5s") } private def makeBlockManager( @@ -137,8 +139,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf = new SparkConf(false) init(conf) - rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) + rpcEnv = RpcEnv.create("test", conf.get(config.DRIVER_HOST_ADDRESS), + conf.get(config.DRIVER_PORT), conf, securityMgr) conf.set(DRIVER_PORT, rpcEnv.address.port) + 
conf.set(DRIVER_HOST_ADDRESS, rpcEnv.address.host) // Mock SparkContext to reduce the memory usage of tests. It's fine since the only reason we // need to create a SparkContext is to initialize LiveListenerBus. @@ -177,6 +181,105 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE blockManager.stop() } + /** + * Setup driverEndpoint, executor-1(BlockManager), executor-2(BlockManager) to simulate + * the real cluster before the tests. Any requests from driver to executor-1 will be responded + * in time. However, any requests from driver to executor-2 will be timeouted, in order to test + * the specific handling of `TimeoutException`, which is raised at driver side. + * + * And, when `withLost` is true, we will not register the executor-2 to the driver. Therefore, + * it behaves like a lost executor in terms of driver's view. When `withLost` is false, we'll + * register the executor-2 normally. + */ + private def setupBlockManagerMasterWithBlocks(withLost: Boolean): Unit = { + // set up a simple DriverEndpoint which simply adds executorIds and + // checks whether a certain executorId has been added before. + val driverEndpoint = rpcEnv.setupEndpoint(CoarseGrainedSchedulerBackend.ENDPOINT_NAME, + new RpcEndpoint { + private val executorSet = mutable.HashSet[String]() + override val rpcEnv: RpcEnv = this.rpcEnv + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case CoarseGrainedClusterMessages.RegisterExecutor(executorId, _, _, _, _, _, _, _) => + executorSet += executorId + context.reply(true) + case CoarseGrainedClusterMessages.IsExecutorAlive(executorId) => + context.reply(executorSet.contains(executorId)) + } + } + ) + + def createAndRegisterBlockManager(timeout: Boolean): BlockManagerId = { + val id = if (timeout) "timeout" else "normal" + val bmRef = rpcEnv.setupEndpoint(s"bm-$id", new RpcEndpoint { + override val rpcEnv: RpcEnv = this.rpcEnv + private def reply[T](context: RpcCallContext, response: T): Unit = { + if (timeout) { + Thread.sleep(conf.getTimeAsMs(Network.RPC_ASK_TIMEOUT.key) + 1000) + } + context.reply(response) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RemoveRdd(_) => reply(context, 1) + case RemoveBroadcast(_, _) => reply(context, 1) + case RemoveShuffle(_) => reply(context, true) + } + }) + val bmId = BlockManagerId(s"exec-$id", "localhost", 1234, None) + master.registerBlockManager(bmId, Array.empty, 2000, 0, bmRef) + } + + // set up normal bm1 + val bm1Id = createAndRegisterBlockManager(false) + // set up bm2, which intentionally takes more time than RPC_ASK_TIMEOUT to + // remove rdd/broadcast/shuffle in order to raise timeout error + val bm2Id = createAndRegisterBlockManager(true) + + driverEndpoint.askSync[Boolean](CoarseGrainedClusterMessages.RegisterExecutor( + bm1Id.executorId, null, bm1Id.host, 1, Map.empty, Map.empty, + Map.empty, 0)) + + if (!withLost) { + driverEndpoint.askSync[Boolean](CoarseGrainedClusterMessages.RegisterExecutor( + bm2Id.executorId, null, bm1Id.host, 1, Map.empty, Map.empty, Map.empty, 0)) + } + + eventually(timeout(5.seconds)) { + // make sure both bm1 and bm2 are registered at driver side BlockManagerMaster + verify(master, times(2)) + .registerBlockManager(mc.any(), mc.any(), mc.any(), mc.any(), mc.any()) + assert(driverEndpoint.askSync[Boolean]( + CoarseGrainedClusterMessages.IsExecutorAlive(bm1Id.executorId))) + assert(driverEndpoint.askSync[Boolean]( + 
CoarseGrainedClusterMessages.IsExecutorAlive(bm2Id.executorId)) === !withLost) + } + + // update RDD block info for bm1 and bm2 (Broadcast and shuffle don't report block + // locations to BlockManagerMaster) + master.updateBlockInfo(bm1Id, RDDBlockId(0, 0), StorageLevel.MEMORY_ONLY, 100, 0) + master.updateBlockInfo(bm2Id, RDDBlockId(0, 1), StorageLevel.MEMORY_ONLY, 100, 0) + } + + test("SPARK-32091: count failures from active executors when remove rdd/broadcast/shuffle") { + setupBlockManagerMasterWithBlocks(false) + // fail because bm2 will timeout and it's not lost anymore + assert(intercept[Exception](master.removeRdd(0, true)) + .getCause.isInstanceOf[TimeoutException]) + assert(intercept[Exception](master.removeBroadcast(0, true, true)) + .getCause.isInstanceOf[TimeoutException]) + assert(intercept[Exception](master.removeShuffle(0, true)) + .getCause.isInstanceOf[TimeoutException]) + } + + test("SPARK-32091: ignore failures from lost executors when remove rdd/broadcast/shuffle") { + setupBlockManagerMasterWithBlocks(true) + // succeed because bm1 will remove rdd/broadcast successfully and bm2 will + // timeout but ignored as it's lost + master.removeRdd(0, true) + master.removeBroadcast(0, true, true) + master.removeShuffle(0, true) + } + test("StorageLevel object caching") { val level1 = StorageLevel(false, false, false, 3) // this should return the same object as level1 @@ -1271,12 +1374,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.master.getLocations("list1").size === 0) assert(store.master.getLocations("list2").size === 1) assert(store.master.getLocations("list3").size === 1) - assert(store.master.getBlockStatus("list1", askSlaves = false).size === 0) - assert(store.master.getBlockStatus("list2", askSlaves = false).size === 1) - assert(store.master.getBlockStatus("list3", askSlaves = false).size === 1) - assert(store.master.getBlockStatus("list1", askSlaves = true).size === 0) - assert(store.master.getBlockStatus("list2", askSlaves = true).size === 1) - assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list1", askStorageEndpoints = false).size === 0) + assert(store.master.getBlockStatus("list2", askStorageEndpoints = false).size === 1) + assert(store.master.getBlockStatus("list3", askStorageEndpoints = false).size === 1) + assert(store.master.getBlockStatus("list1", askStorageEndpoints = true).size === 0) + assert(store.master.getBlockStatus("list2", askStorageEndpoints = true).size === 1) + assert(store.master.getBlockStatus("list3", askStorageEndpoints = true).size === 1) // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. 
store.putIterator( @@ -1287,17 +1390,17 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE "list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) // getLocations should return nothing because the master is not informed - // getBlockStatus without asking slaves should have the same result - // getBlockStatus with asking slaves, however, should return the actual block statuses + // getBlockStatus without asking storage endpoints should have the same result + // getBlockStatus with asking storage endpoints, however, should return the actual statuses assert(store.master.getLocations("list4").size === 0) assert(store.master.getLocations("list5").size === 0) assert(store.master.getLocations("list6").size === 0) - assert(store.master.getBlockStatus("list4", askSlaves = false).size === 0) - assert(store.master.getBlockStatus("list5", askSlaves = false).size === 0) - assert(store.master.getBlockStatus("list6", askSlaves = false).size === 0) - assert(store.master.getBlockStatus("list4", askSlaves = true).size === 0) - assert(store.master.getBlockStatus("list5", askSlaves = true).size === 1) - assert(store.master.getBlockStatus("list6", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list4", askStorageEndpoints = false).size === 0) + assert(store.master.getBlockStatus("list5", askStorageEndpoints = false).size === 0) + assert(store.master.getBlockStatus("list6", askStorageEndpoints = false).size === 0) + assert(store.master.getBlockStatus("list4", askStorageEndpoints = true).size === 0) + assert(store.master.getBlockStatus("list5", askStorageEndpoints = true).size === 1) + assert(store.master.getBlockStatus("list6", askStorageEndpoints = true).size === 1) } test("get matching blocks") { @@ -1313,9 +1416,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE "list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size + assert(store.master.getMatchingBlockIds( + _.toString.contains("list"), askStorageEndpoints = false).size === 3) - assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size + assert(store.master.getMatchingBlockIds( + _.toString.contains("list1"), askStorageEndpoints = false).size === 1) // insert some more blocks @@ -1327,9 +1432,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE "newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size + assert( + store.master.getMatchingBlockIds( + _.toString.contains("newlist"), askStorageEndpoints = false).size === 1) - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size + assert( + store.master.getMatchingBlockIds( + _.toString.contains("newlist"), askStorageEndpoints = true).size === 3) val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) @@ -1340,7 +1449,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val matchedBlockIds = store.master.getMatchingBlockIds(_ match { case RDDBlockId(1, _) => true case _ => false - }, askSlaves = true) + }, askStorageEndpoints = true) assert(matchedBlockIds.toSet === Set(RDDBlockId(1, 0), 
RDDBlockId(1, 1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 43917a5b83bb0..bf1379ceb89a8 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1047,7 +1047,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() ) - val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0))) + val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)).toMap) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index ecfdf481f4f6c..4f808f03e5dab 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -48,24 +48,28 @@ import org.apache.spark.util.CallSite private[spark] class SparkUICssErrorHandler extends DefaultCssErrorHandler { - private val cssWhiteList = List("bootstrap.min.css", "vis-timeline-graph2d.min.css") + /** + * Some libraries have warn/error messages that are too noisy for the tests; exclude them from + * normal error handling to avoid logging these. + */ + private val cssExcludeList = List("bootstrap.min.css", "vis-timeline-graph2d.min.css") - private def isInWhileList(uri: String): Boolean = cssWhiteList.exists(uri.endsWith) + private def isInExcludeList(uri: String): Boolean = cssExcludeList.exists(uri.endsWith) override def warning(e: CSSParseException): Unit = { - if (!isInWhileList(e.getURI)) { + if (!isInExcludeList(e.getURI)) { super.warning(e) } } override def fatalError(e: CSSParseException): Unit = { - if (!isInWhileList(e.getURI)) { + if (!isInExcludeList(e.getURI)) { super.fatalError(e) } } override def error(e: CSSParseException): Unit = { - if (!isInWhileList(e.getURI)) { + if (!isInExcludeList(e.getURI)) { super.error(e) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 955589fc5b47b..c75e98f39758d 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -334,7 +334,7 @@ class JsonProtocolSuite extends SparkFunSuite { val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed) .removeField({ _._1 == "Map Index" }) val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 16L, - 0, 19, "ignored") + Int.MinValue, 19, "ignored") assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index c9c8ae6023877..7ec7c5afca1df 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1309,6 +1309,103 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.buildLocationMetadata(paths, 15) == "[path0, path1, path2]") assert(Utils.buildLocationMetadata(paths, 25) == 
"[path0, path1, path2, path3]") } + + test("checkHost supports both IPV4 and IPV6") { + // IPV4 ips + Utils.checkHost("0.0.0.0") + var e: AssertionError = intercept[AssertionError] { + Utils.checkHost("0.0.0.0:0") + } + assert(e.getMessage.contains("Expected hostname or IP but got 0.0.0.0:0")) + e = intercept[AssertionError] { + Utils.checkHost("0.0.0.0:") + } + assert(e.getMessage.contains("Expected hostname or IP but got 0.0.0.0:")) + // IPV6 ips + Utils.checkHost("[::1]") + e = intercept[AssertionError] { + Utils.checkHost("[::1]:0") + } + assert(e.getMessage.contains("Expected hostname or IPv6 IP enclosed in [] but got [::1]:0")) + e = intercept[AssertionError] { + Utils.checkHost("[::1]:") + } + assert(e.getMessage.contains("Expected hostname or IPv6 IP enclosed in [] but got [::1]:")) + // hostname + Utils.checkHost("localhost") + e = intercept[AssertionError] { + Utils.checkHost("localhost:0") + } + assert(e.getMessage.contains("Expected hostname or IP but got localhost:0")) + e = intercept[AssertionError] { + Utils.checkHost("localhost:") + } + assert(e.getMessage.contains("Expected hostname or IP but got localhost:")) + } + + test("checkHostPort support IPV6 and IPV4") { + // IPV4 ips + Utils.checkHostPort("0.0.0.0:0") + var e: AssertionError = intercept[AssertionError] { + Utils.checkHostPort("0.0.0.0") + } + assert(e.getMessage.contains("Expected host and port but got 0.0.0.0")) + + // IPV6 ips + Utils.checkHostPort("[::1]:0") + e = intercept[AssertionError] { + Utils.checkHostPort("[::1]") + } + assert(e.getMessage.contains("Expected host and port but got [::1]")) + + // hostname + Utils.checkHostPort("localhost:0") + e = intercept[AssertionError] { + Utils.checkHostPort("localhost") + } + assert(e.getMessage.contains("Expected host and port but got localhost")) + } + + test("parseHostPort support IPV6 and IPV4") { + // IPV4 ips + var hostnamePort = Utils.parseHostPort("0.0.0.0:80") + assert(hostnamePort._1.equals("0.0.0.0")) + assert(hostnamePort._2 === 80) + + hostnamePort = Utils.parseHostPort("0.0.0.0") + assert(hostnamePort._1.equals("0.0.0.0")) + assert(hostnamePort._2 === 0) + + hostnamePort = Utils.parseHostPort("0.0.0.0:") + assert(hostnamePort._1.equals("0.0.0.0")) + assert(hostnamePort._2 === 0) + + // IPV6 ips + hostnamePort = Utils.parseHostPort("[::1]:80") + assert(hostnamePort._1.equals("[::1]")) + assert(hostnamePort._2 === 80) + + hostnamePort = Utils.parseHostPort("[::1]") + assert(hostnamePort._1.equals("[::1]")) + assert(hostnamePort._2 === 0) + + hostnamePort = Utils.parseHostPort("[::1]:") + assert(hostnamePort._1.equals("[::1]")) + assert(hostnamePort._2 === 0) + + // hostname + hostnamePort = Utils.parseHostPort("localhost:80") + assert(hostnamePort._1.equals("localhost")) + assert(hostnamePort._2 === 80) + + hostnamePort = Utils.parseHostPort("localhost") + assert(hostnamePort._1.equals("localhost")) + assert(hostnamePort._2 === 0) + + hostnamePort = Utils.parseHostPort("localhost:") + assert(hostnamePort._1.equals("localhost")) + assert(hostnamePort._2 === 0) + } } private class SimpleExtension diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index 87eb82935e4e0..e344a7fc23191 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -67,8 +67,8 @@ Function InstallRtools { Else { $gccPath = $env:GCC_PATH } - $env:PATH = $RtoolsDrive + '\Rtools40\bin;' + $RtoolsDrive + '\Rtools40\MinGW$(WIN)\bin;' + $RtoolsDrive + '\Rtools40\' + $gccPath + '\bin;' + $env:PATH - 
$env:BINPREF=$RtoolsDrive + '/Rtools40/mingw$(WIN)/bin/' + $env:PATH = $RtoolsDrive + '\Rtools40\bin;' + $RtoolsDrive + '\Rtools40\mingw64\bin;' + $RtoolsDrive + '\Rtools40\' + $gccPath + '\bin;' + $env:PATH + $env:BINPREF=$RtoolsDrive + '/Rtools40/mingw64/bin/' } # create tools directory outside of Spark directory @@ -95,22 +95,22 @@ $env:MAVEN_OPTS = "-Xmx2g -XX:ReservedCodeCacheSize=1g" Pop-Location # ========================== Hadoop bin package -# This must match the version at https://github.com/steveloughran/winutils/tree/master/hadoop-2.7.1 -$hadoopVer = "2.7.1" +# This must match the version at https://github.com/cdarlint/winutils/tree/master/hadoop-3.2.0 +$hadoopVer = "3.2.0" $hadoopPath = "$tools\hadoop" if (!(Test-Path $hadoopPath)) { New-Item -ItemType Directory -Force -Path $hadoopPath | Out-Null } Push-Location $hadoopPath -Start-FileDownload "https://github.com/steveloughran/winutils/archive/master.zip" "winutils-master.zip" +Start-FileDownload "https://codeload.github.com/cdarlint/winutils/zip/master" "winutils-master.zip" # extract Invoke-Expression "7z.exe x winutils-master.zip" # add hadoop bin to environment variables -$env:HADOOP_HOME = "$hadoopPath/winutils-master/hadoop-$hadoopVer" -$env:Path += ";$env:HADOOP_HOME\bin" +$env:HADOOP_HOME = "$hadoopPath\winutils-master\hadoop-$hadoopVer" +$env:PATH = "$env:HADOOP_HOME\bin;" + $env:PATH Pop-Location diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index a5a26ae8f5354..241b7ed539ae9 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -49,8 +49,6 @@ print("Install using 'sudo pip install unidecode'") sys.exit(-1) -if sys.version < '3': - input = raw_input # noqa # Contributors list file name contributors_file_name = "contributors.txt" @@ -152,10 +150,7 @@ def get_commits(tag): if not is_valid_author(author): author = github_username # Guard against special characters - try: # Python 2 - author = unicode(author, "UTF-8") - except NameError: # Python 3 - author = str(author) + author = str(author) author = unidecode.unidecode(author).strip() commit = Commit(_hash, author, title, pr_number) commits.append(commit) diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-1.2 b/dev/deps/spark-deps-hadoop-2.7-hive-1.2 index f8a43488d0f7f..344806e447689 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-1.2 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-1.2 @@ -208,4 +208,4 @@ xmlenc/0.52//xmlenc-0.52.jar xz/1.5//xz-1.5.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper/3.4.14//zookeeper-3.4.14.jar -zstd-jni/1.4.5-2//zstd-jni-1.4.5-2.jar +zstd-jni/1.4.5-4//zstd-jni-1.4.5-4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 index a34970b3c9d1d..969249b963e7b 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 @@ -222,4 +222,4 @@ xmlenc/0.52//xmlenc-0.52.jar xz/1.5//xz-1.5.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper/3.4.14//zookeeper-3.4.14.jar -zstd-jni/1.4.5-2//zstd-jni-1.4.5-2.jar +zstd-jni/1.4.5-4//zstd-jni-1.4.5-4.jar diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index c8fade45739c0..e98e4676107ed 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -236,4 +236,4 @@ xbean-asm7-shaded/4.15//xbean-asm7-shaded-4.15.jar xz/1.5//xz-1.5.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper/3.4.14//zookeeper-3.4.14.jar -zstd-jni/1.4.5-2//zstd-jni-1.4.5-2.jar 
+zstd-jni/1.4.5-4//zstd-jni-1.4.5-4.jar diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py index b444b74d4027c..b90afeebc5238 100755 --- a/dev/github_jira_sync.py +++ b/dev/github_jira_sync.py @@ -22,14 +22,9 @@ import os import re import sys -if sys.version < '3': - from urllib2 import urlopen - from urllib2 import Request - from urllib2 import HTTPError -else: - from urllib.request import urlopen - from urllib.request import Request - from urllib.error import HTTPError +from urllib.request import urlopen +from urllib.request import Request +from urllib.error import HTTPError try: import jira.client diff --git a/dev/lint-python b/dev/lint-python index d5491f2447176..1fddbfa64b32c 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -168,7 +168,15 @@ function sphinx_test { # Check that the documentation builds acceptably, skip check if sphinx is not installed. if ! hash "$SPHINX_BUILD" 2> /dev/null; then - echo "The $SPHINX_BUILD command was not found. Skipping pydoc checks for now." + echo "The $SPHINX_BUILD command was not found. Skipping Sphinx build for now." + echo + return + fi + + # TODO(SPARK-32279): Install Sphinx in Python 3 of Jenkins machines + PYTHON_HAS_SPHINX=$("$PYTHON_EXECUTABLE" -c 'import importlib.util; print(importlib.util.find_spec("sphinx") is not None)') + if [[ "$PYTHON_HAS_SPHINX" == "False" ]]; then + echo "$PYTHON_EXECUTABLE does not have Sphinx installed. Skipping Sphinx build for now." echo return fi diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 967cdace60dc9..b42429d7175b1 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -31,15 +31,9 @@ import subprocess import sys import traceback -if sys.version < '3': - input = raw_input # noqa - from urllib2 import urlopen - from urllib2 import Request - from urllib2 import HTTPError -else: - from urllib.request import urlopen - from urllib.request import Request - from urllib.error import HTTPError +from urllib.request import urlopen +from urllib.request import Request +from urllib.error import HTTPError try: import jira.client diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 470f21e69d46a..5fd0be7476f29 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -63,7 +63,7 @@ fi PYSPARK_VERSION=$(python3 -c "exec(open('python/pyspark/version.py').read());print(__version__)") PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" # The pip install options we use for all the pip commands -PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall " +PIP_OPTIONS="--user --upgrade --no-cache-dir --force-reinstall " # Test both regular user and edit/dev install modes. PIP_COMMANDS=("pip install $PIP_OPTIONS $PYSPARK_DIST" "pip install $PIP_OPTIONS -e python/") @@ -76,8 +76,12 @@ for python in "${PYTHON_EXECS[@]}"; do VIRTUALENV_PATH="$VIRTUALENV_BASE"/$python rm -rf "$VIRTUALENV_PATH" if [ -n "$USE_CONDA" ]; then + if [ -f "$CONDA_PREFIX/etc/profile.d/conda.sh" ]; then + # See also https://github.com/conda/conda/issues/7980 + source "$CONDA_PREFIX/etc/profile.d/conda.sh" + fi conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools - source activate "$VIRTUALENV_PATH" + conda activate "$VIRTUALENV_PATH" || (echo "Falling back to 'source activate'" && source activate "$VIRTUALENV_PATH") else mkdir -p "$VIRTUALENV_PATH" virtualenv --python=$python "$VIRTUALENV_PATH" @@ -92,6 +96,8 @@ for python in "${PYTHON_EXECS[@]}"; do cd "$FWDIR"/python # Delete the egg info file if it exists, this can cache the setup file. 
rm -rf pyspark.egg-info || echo "No existing egg info file, skipping deletion" + # Also, delete the symbolic link if exists. It can be left over from the previous editable mode installation. + python3 -c "from distutils.sysconfig import get_python_lib; import os; f = os.path.join(get_python_lib(), 'pyspark.egg-link'); os.unlink(f) if os.path.isfile(f) else 0" python3 setup.py sdist @@ -110,6 +116,7 @@ for python in "${PYTHON_EXECS[@]}"; do cd / echo "Run basic sanity check on pip installed version with spark-submit" + export PATH="$(python3 -m site --user-base)/bin:$PATH" spark-submit "$FWDIR"/dev/pip-sanity-check.py echo "Run basic sanity check with import based" python3 "$FWDIR"/dev/pip-sanity-check.py @@ -120,7 +127,7 @@ for python in "${PYTHON_EXECS[@]}"; do # conda / virtualenv environments need to be deactivated differently if [ -n "$USE_CONDA" ]; then - source deactivate + conda deactivate || (echo "Falling back to 'source deactivate'" && source deactivate) else deactivate fi diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 13be9592d771f..4ff5b327e3325 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -22,15 +22,9 @@ import json import functools import subprocess -if sys.version < '3': - from urllib2 import urlopen - from urllib2 import Request - from urllib2 import HTTPError, URLError -else: - from urllib.request import urlopen - from urllib.request import Request - from urllib.error import HTTPError, URLError - +from urllib.request import urlopen +from urllib.request import Request +from urllib.error import HTTPError, URLError from sparktestsupport import SPARK_HOME, ERROR_CODES from sparktestsupport.shellutils import run_cmd diff --git a/dev/run-tests.py b/dev/run-tests.py index ca502b2818847..065a27c0e853b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -101,28 +101,52 @@ def setup_test_environ(environ): os.environ[k] = v -def determine_modules_to_test(changed_modules): +def determine_modules_to_test(changed_modules, deduplicated=True): """ Given a set of modules that have changed, compute the transitive closure of those modules' dependent modules in order to determine the set of modules that should be tested. Returns a topologically-sorted list of modules (ties are broken by sorting on module names). + If ``deduplicated`` is disabled, the modules are returned without tacking the deduplication + by dependencies into account. >>> [x.name for x in determine_modules_to_test([modules.root])] ['root'] >>> [x.name for x in determine_modules_to_test([modules.build])] ['root'] + >>> [x.name for x in determine_modules_to_test([modules.core])] + ['root'] + >>> [x.name for x in determine_modules_to_test([modules.launcher])] + ['root'] >>> [x.name for x in determine_modules_to_test([modules.graphx])] ['graphx', 'examples'] - >>> x = [x.name for x in determine_modules_to_test([modules.sql])] - >>> x # doctest: +NORMALIZE_WHITESPACE + >>> [x.name for x in determine_modules_to_test([modules.sql])] + ... # doctest: +NORMALIZE_WHITESPACE ['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] + >>> sorted([x.name for x in determine_modules_to_test( + ... [modules.sparkr, modules.sql], deduplicated=False)]) + ... 
# doctest: +NORMALIZE_WHITESPACE + ['avro', 'examples', 'hive', 'hive-thriftserver', 'mllib', 'pyspark-ml', + 'pyspark-mllib', 'pyspark-sql', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10'] + >>> sorted([x.name for x in determine_modules_to_test( + ... [modules.sql, modules.core], deduplicated=False)]) + ... # doctest: +NORMALIZE_WHITESPACE + ['avro', 'catalyst', 'core', 'examples', 'graphx', 'hive', 'hive-thriftserver', + 'mllib', 'mllib-local', 'pyspark-core', 'pyspark-ml', 'pyspark-mllib', + 'pyspark-resource', 'pyspark-sql', 'pyspark-streaming', 'repl', 'root', + 'sparkr', 'sql', 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10', + 'streaming-kinesis-asl'] """ modules_to_test = set() for module in changed_modules: - modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) + modules_to_test = modules_to_test.union( + determine_modules_to_test(module.dependent_modules, deduplicated)) modules_to_test = modules_to_test.union(set(changed_modules)) + + if not deduplicated: + return modules_to_test + # If we need to run all of the tests, then we should short-circuit and return 'root' if modules.root in modules_to_test: return [modules.root] @@ -363,7 +387,8 @@ def build_spark_assembly_sbt(extra_profiles, checkstyle=False): if checkstyle: run_java_style_checks(build_profiles) - build_spark_unidoc_sbt(extra_profiles) + if not os.environ.get("AMPLAB_JENKINS"): + build_spark_unidoc_sbt(extra_profiles) def build_apache_spark(build_tool, extra_profiles): @@ -415,7 +440,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, extra_profiles, test_modules, excluded_tags): +def run_scala_tests(build_tool, extra_profiles, test_modules, excluded_tags, included_tags): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -425,6 +450,8 @@ def run_scala_tests(build_tool, extra_profiles, test_modules, excluded_tags): test_profiles = extra_profiles + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) + if included_tags: + test_profiles += ['-Dtest.include.tags=' + ",".join(included_tags)] if excluded_tags: test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] @@ -532,6 +559,24 @@ def parse_opts(): "-p", "--parallelism", type=int, default=8, help="The number of suites to test in parallel (default %(default)d)" ) + parser.add_argument( + "-m", "--modules", type=str, + default=None, + help="A comma-separated list of modules to test " + "(default: %s)" % ",".join(sorted([m.name for m in modules.all_modules])) + ) + parser.add_argument( + "-e", "--excluded-tags", type=str, + default=None, + help="A comma-separated list of tags to exclude in the tests, " + "e.g., org.apache.spark.tags.ExtendedHiveTest " + ) + parser.add_argument( + "-i", "--included-tags", type=str, + default=None, + help="A comma-separated list of tags to include in the tests, " + "e.g., org.apache.spark.tags.ExtendedHiveTest " + ) args, unknown = parser.parse_known_args() if unknown: @@ -582,27 +627,74 @@ def main(): # /home/jenkins/anaconda2/envs/py36/bin os.environ["PATH"] = "/home/anaconda/envs/py36/bin:" + os.environ.get("PATH") else: - # else we're running locally and can use local settings + # else we're running locally or Github Actions. 
build_tool = "sbt" hadoop_version = os.environ.get("HADOOP_PROFILE", "hadoop2.7") hive_version = os.environ.get("HIVE_PROFILE", "hive2.3") - test_env = "local" + if "GITHUB_ACTIONS" in os.environ: + test_env = "github_actions" + else: + test_env = "local" print("[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, "and Hive profile", hive_version, "under environment", test_env) extra_profiles = get_hadoop_profiles(hadoop_version) + get_hive_profiles(hive_version) changed_modules = None + test_modules = None changed_files = None - if test_env == "amplab_jenkins" and os.environ.get("AMP_JENKINS_PRB"): + should_only_test_modules = opts.modules is not None + included_tags = [] + excluded_tags = [] + if should_only_test_modules: + str_test_modules = [m.strip() for m in opts.modules.split(",")] + test_modules = [m for m in modules.all_modules if m.name in str_test_modules] + + # If we're running the tests in Github Actions, attempt to detect and test + # only the affected modules. + if test_env == "github_actions": + if os.environ["GITHUB_BASE_REF"] != "": + # Pull requests + changed_files = identify_changed_files_from_git_commits( + os.environ["GITHUB_SHA"], target_branch=os.environ["GITHUB_BASE_REF"]) + else: + # Build for each commit. + changed_files = identify_changed_files_from_git_commits( + os.environ["GITHUB_SHA"], target_ref=os.environ["GITHUB_PREV_SHA"]) + + modules_to_test = determine_modules_to_test( + determine_modules_for_files(changed_files), deduplicated=False) + + if modules.root not in modules_to_test: + # If root module is not found, only test the intersected modules. + # If root module is found, just run the modules as specified initially. + test_modules = list(set(modules_to_test).intersection(test_modules)) + + changed_modules = test_modules + if len(changed_modules) == 0: + print("[info] There are no modules to test, exiting without testing.") + return + + # If we're running the tests in AMPLab Jenkins, calculate the diff from the targeted branch, and + # detect modules to test. + elif test_env == "amplab_jenkins" and os.environ.get("AMP_JENKINS_PRB"): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) + test_modules = determine_modules_to_test(changed_modules) excluded_tags = determine_tags_to_exclude(changed_modules) + # If there is no changed module found, tests all. 
if not changed_modules: changed_modules = [modules.root] - excluded_tags = [] + if not test_modules: + test_modules = determine_modules_to_test(changed_modules) + + if opts.excluded_tags: + excluded_tags.extend([t.strip() for t in opts.excluded_tags.split(",")]) + if opts.included_tags: + included_tags.extend([t.strip() for t in opts.included_tags.split(",")]) + print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -615,33 +707,32 @@ def main(): test_environ.update(m.environ) setup_test_environ(test_environ) - test_modules = determine_modules_to_test(changed_modules) - - # license checks - run_apache_rat_checks() - - # style checks - if not changed_files or any(f.endswith(".scala") - or f.endswith("scalastyle-config.xml") - for f in changed_files): - run_scala_style_checks(extra_profiles) should_run_java_style_checks = False - if not changed_files or any(f.endswith(".java") - or f.endswith("checkstyle.xml") - or f.endswith("checkstyle-suppressions.xml") - for f in changed_files): - # Run SBT Checkstyle after the build to prevent a side-effect to the build. - should_run_java_style_checks = True - if not changed_files or any(f.endswith("lint-python") - or f.endswith("tox.ini") - or f.endswith(".py") - for f in changed_files): - run_python_style_checks() - if not changed_files or any(f.endswith(".R") - or f.endswith("lint-r") - or f.endswith(".lintr") - for f in changed_files): - run_sparkr_style_checks() + if not should_only_test_modules: + # license checks + run_apache_rat_checks() + + # style checks + if not changed_files or any(f.endswith(".scala") + or f.endswith("scalastyle-config.xml") + for f in changed_files): + run_scala_style_checks(extra_profiles) + if not changed_files or any(f.endswith(".java") + or f.endswith("checkstyle.xml") + or f.endswith("checkstyle-suppressions.xml") + for f in changed_files): + # Run SBT Checkstyle after the build to prevent a side-effect to the build. + should_run_java_style_checks = True + if not changed_files or any(f.endswith("lint-python") + or f.endswith("tox.ini") + or f.endswith(".py") + for f in changed_files): + run_python_style_checks() + if not changed_files or any(f.endswith(".R") + or f.endswith("lint-r") + or f.endswith(".lintr") + for f in changed_files): + run_sparkr_style_checks() # determine if docs were changed and if we're inside the amplab environment # note - the below commented out until *all* Jenkins workers can get `jekyll` installed @@ -663,7 +754,7 @@ def main(): build_spark_assembly_sbt(extra_profiles, should_run_java_style_checks) # run the test suites - run_scala_tests(build_tool, extra_profiles, test_modules, excluded_tags) + run_scala_tests(build_tool, extra_profiles, test_modules, excluded_tags, included_tags) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 85e6a1e9fadac..3c438e309c22d 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -32,7 +32,7 @@ class Module(object): """ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, - sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), + sbt_test_goals=(), python_test_goals=(), excluded_python_implementations=(), test_tags=(), should_run_r_tests=False, should_run_build_tests=False): """ Define a new module. 
@@ -49,7 +49,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= module are changed. :param sbt_test_goals: A set of SBT test goals for testing this module. :param python_test_goals: A set of Python test goals for testing this module. - :param blacklisted_python_implementations: A set of Python implementations that are not + :param excluded_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. :param test_tags A set of tags that will be excluded when running unit tests if the module @@ -64,7 +64,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.build_profile_flags = build_profile_flags self.environ = environ self.python_test_goals = python_test_goals - self.blacklisted_python_implementations = blacklisted_python_implementations + self.excluded_python_implementations = excluded_python_implementations self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.should_run_build_tests = should_run_build_tests @@ -100,9 +100,75 @@ def __hash__(self): ] ) +kvstore = Module( + name="kvstore", + dependencies=[tags], + source_file_regexes=[ + "common/kvstore/", + ], + sbt_test_goals=[ + "kvstore/test", + ], +) + +network_common = Module( + name="network-common", + dependencies=[tags], + source_file_regexes=[ + "common/network-common/", + ], + sbt_test_goals=[ + "network-common/test", + ], +) + +network_shuffle = Module( + name="network-shuffle", + dependencies=[tags], + source_file_regexes=[ + "common/network-shuffle/", + ], + sbt_test_goals=[ + "network-shuffle/test", + ], +) + +unsafe = Module( + name="unsafe", + dependencies=[tags], + source_file_regexes=[ + "common/unsafe", + ], + sbt_test_goals=[ + "unsafe/test", + ], +) + +launcher = Module( + name="launcher", + dependencies=[tags], + source_file_regexes=[ + "launcher/", + ], + sbt_test_goals=[ + "launcher/test", + ], +) + +core = Module( + name="core", + dependencies=[kvstore, network_common, network_shuffle, unsafe, launcher], + source_file_regexes=[ + "core/", + ], + sbt_test_goals=[ + "core/test", + ], +) + catalyst = Module( name="catalyst", - dependencies=[tags], + dependencies=[tags, core], source_file_regexes=[ "sql/catalyst/", ], @@ -111,7 +177,6 @@ def __hash__(self): ], ) - sql = Module( name="sql", dependencies=[catalyst], @@ -123,7 +188,6 @@ def __hash__(self): ], ) - hive = Module( name="hive", dependencies=[sql], @@ -142,7 +206,6 @@ def __hash__(self): ] ) - repl = Module( name="repl", dependencies=[hive], @@ -154,7 +217,6 @@ def __hash__(self): ], ) - hive_thriftserver = Module( name="hive-thriftserver", dependencies=[hive], @@ -192,7 +254,6 @@ def __hash__(self): ] ) - sketch = Module( name="sketch", dependencies=[tags], @@ -204,10 +265,9 @@ def __hash__(self): ] ) - graphx = Module( name="graphx", - dependencies=[tags], + dependencies=[tags, core], source_file_regexes=[ "graphx/", ], @@ -216,10 +276,9 @@ def __hash__(self): ] ) - streaming = Module( name="streaming", - dependencies=[tags], + dependencies=[tags, core], source_file_regexes=[ "streaming", ], @@ -235,7 +294,7 @@ def __hash__(self): # fail other PRs. 
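
As a quick illustration of the pattern used throughout this file (and picked up by the change detection above), a hypothetical downstream module would be wired the same way: a name, its dependencies, the source prefixes used for change detection, and the SBT goals to run. This is only a sketch; the identifiers and paths below are invented, while `Module` and `core` are the real objects from this file (assuming `dev/` is on `PYTHONPATH`).

```python
from sparktestsupport.modules import Module, core

my_connector = Module(
    name="my-connector",
    dependencies=[core],           # a change to core also triggers this module's tests
    source_file_regexes=[
        "external/my-connector/",  # changed files matching this prefix map to the module
    ],
    sbt_test_goals=[
        "my-connector/test",
    ],
)
```
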
streaming_kinesis_asl = Module( name="streaming-kinesis-asl", - dependencies=[tags], + dependencies=[tags, core], source_file_regexes=[ "external/kinesis-asl/", "external/kinesis-asl-assembly/", @@ -254,21 +313,23 @@ def __hash__(self): streaming_kafka_0_10 = Module( name="streaming-kafka-0-10", - dependencies=[streaming], + dependencies=[streaming, core], source_file_regexes=[ # The ending "/" is necessary otherwise it will include "sql-kafka" codes "external/kafka-0-10/", "external/kafka-0-10-assembly", + "external/kafka-0-10-token-provider", ], sbt_test_goals=[ "streaming-kafka-0-10/test", + "token-provider-kafka-0-10/test" ] ) mllib_local = Module( name="mllib-local", - dependencies=[tags], + dependencies=[tags, core], source_file_regexes=[ "mllib-local", ], @@ -302,10 +363,9 @@ def __hash__(self): ] ) - pyspark_core = Module( name="pyspark-core", - dependencies=[], + dependencies=[core], source_file_regexes=[ "python/(?!pyspark/(ml|mllib|sql|streaming))" ], @@ -339,7 +399,6 @@ def __hash__(self): ] ) - pyspark_sql = Module( name="pyspark-sql", dependencies=[pyspark_core, hive, avro], @@ -465,7 +524,7 @@ def __hash__(self): "pyspark.mllib.tests.test_streaming_algorithms", "pyspark.mllib.tests.test_util", ], - blacklisted_python_implementations=[ + excluded_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there ] ) @@ -506,7 +565,7 @@ def __hash__(self): "pyspark.ml.tests.test_tuning", "pyspark.ml.tests.test_wrapper", ], - blacklisted_python_implementations=[ + excluded_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there ] ) @@ -593,7 +652,7 @@ def __hash__(self): # No other modules should directly depend on this module. root = Module( name="root", - dependencies=[build], # Changes to build should trigger all tests. + dependencies=[build, core], # Changes to build should trigger all tests. source_file_regexes=[], # In order to run all of the tests, enable every test profile: build_profile_flags=list(set( diff --git a/dev/sparktestsupport/toposort.py b/dev/sparktestsupport/toposort.py index 8b2688d20039f..6785e481b56b5 100644 --- a/dev/sparktestsupport/toposort.py +++ b/dev/sparktestsupport/toposort.py @@ -24,8 +24,7 @@ # Moved functools import to the top of the file. # Changed assert to a ValueError. # Changed iter[items|keys] to [items|keys], for python 3 -# compatibility. I don't think it matters for python 2 these are -# now lists instead of iterables. +# compatibility. # Copy the input so as to leave it unmodified. # Renamed function from toposort2 to toposort. # Handle empty input. diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 01f437f38ef17..749d026528017 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -103,7 +103,7 @@ for talking to cloud infrastructures, in which case this module may not be neede Spark jobs must authenticate with the object stores to access data within them. 1. When Spark is running in a cloud infrastructure, the credentials are usually automatically set up. -1. `spark-submit` reads the `AWS_ACCESS_KEY`, `AWS_SECRET_KEY` +1. `spark-submit` reads the `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_SESSION_TOKEN` environment variables and sets the associated authentication options for the `s3n` and `s3a` connectors to Amazon S3. 1. In a Hadoop cluster, settings may be set in the `core-site.xml` file. 
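
Outside of `spark-submit`, the same environment variables can be wired into the `s3a` connector by hand when building a session. The sketch below is illustrative only: it assumes the `hadoop-aws` module is on the classpath and uses a made-up bucket; the `fs.s3a.*` keys are standard Hadoop connector options, not something introduced by this change.

```python
import os
from pyspark.sql import SparkSession

builder = (
    SparkSession.builder
    .appName("s3a-credentials-sketch")
    .config("spark.hadoop.fs.s3a.access.key", os.environ["AWS_ACCESS_KEY_ID"])
    .config("spark.hadoop.fs.s3a.secret.key", os.environ["AWS_SECRET_ACCESS_KEY"])
)
# Only pass a session token when one is actually present (temporary credentials).
if os.environ.get("AWS_SESSION_TOKEN"):
    builder = builder.config("spark.hadoop.fs.s3a.session.token", os.environ["AWS_SESSION_TOKEN"])

spark = builder.getOrCreate()
df = spark.read.text("s3a://my-example-bucket/logs/")  # hypothetical bucket
```
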
diff --git a/docs/configuration.md b/docs/configuration.md
index 706c2552b1d17..abf76105ae77d 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1890,7 +1890,7 @@ Apart from these, the following properties are also available, and may be useful
    Default timeout for all network interactions. This config will be used in place of
    spark.core.connection.ack.wait.timeout,
-   spark.storage.blockManagerSlaveTimeoutMs,
+   spark.storage.blockManagerHeartbeatTimeoutMs,
    spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or
    spark.rpc.lookupTimeout if they are not configured.
@@ -2917,7 +2917,7 @@ The following variables can be set in `spark-env.sh`:
   PYSPARK_PYTHON
-  Python binary executable to use for PySpark in both driver and workers (default is python2.7 if available, otherwise python).
+  Python binary executable to use for PySpark in both driver and workers (default is python3 if available, otherwise python).
   Property spark.pyspark.python takes precedence if it is set
diff --git a/docs/index.md b/docs/index.md
index c0771ca170af5..8fd169e63f608 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -44,9 +44,8 @@ source, visit [Building Spark](building-spark.html).
 Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS), and it should run on
 any platform that runs a supported version of Java. This should include JVMs on x86_64 and ARM64.
 It's easy to run locally on one machine --- all you need is to have `java` installed on your system `PATH`,
 or the `JAVA_HOME` environment variable pointing to a Java installation.
-Spark runs on Java 8/11, Scala 2.12, Python 2.7+/3.4+ and R 3.5+.
+Spark runs on Java 8/11, Scala 2.12, Python 3.6+ and R 3.5+.
 Java 8 prior to version 8u92 support is deprecated as of Spark 3.0.0.
-Python 2 and Python 3 prior to version 3.6 support is deprecated as of Spark 3.0.0.
 For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}.
 You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x).
diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index eaacfa49c657c..5c19c77f37a81 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -95,7 +95,7 @@ varies across cluster managers:
 In standalone mode, simply start your workers with `spark.shuffle.service.enabled` set to `true`.
 In Mesos coarse-grained mode, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all
-slave nodes with `spark.shuffle.service.enabled` set to `true`. For instance, you may do so
+worker nodes with `spark.shuffle.service.enabled` set to `true`. For instance, you may do so
 through Marathon.
 In YARN mode, follow the instructions [here](running-on-yarn.html#configuring-the-external-shuffle-service).
diff --git a/docs/ml-datasource.md b/docs/ml-datasource.md
index 0f2f5f482ec50..8e9c947b75f38 100644
--- a/docs/ml-datasource.md
+++ b/docs/ml-datasource.md
@@ -86,7 +86,7 @@ Will output:
 In PySpark we provide Spark SQL data source API for loading image data as a DataFrame.
 {% highlight python %}
->>> df = spark.read.format("image").option("dropInvalid", true).load("data/mllib/images/origin/kittens")
+>>> df = spark.read.format("image").option("dropInvalid", True).load("data/mllib/images/origin/kittens")
 >>> df.select("image.origin", "image.width", "image.height").show(truncate=False)
 +-----------------------------------------------------------------------+-----+------+
 |origin                                                                 |width|height|
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 32959b77c4773..2ab7b30a1dca9 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -373,6 +373,25 @@ Security options for the Spark History Server are covered in more detail in the
   3.0.0
+  spark.history.store.hybridStore.enabled
+  false
+    Whether to use HybridStore as the store when parsing event logs. HybridStore will first write data
+    to an in-memory store and have a background thread that dumps data to a disk store after the writing
+    to the in-memory store is completed.
+  3.1.0
+  spark.history.store.hybridStore.maxMemoryUsage
+  2g
+    Maximum memory space that can be used to create HybridStore. The HybridStore co-uses the heap memory,
+    so the heap memory should be increased through the memory option for SHS if the HybridStore is enabled.
+  3.1.0
 Note that in all of these UIs, the tables are sortable by clicking their headers,
@@ -480,7 +499,8 @@ can be identified by their `[attempt-id]`. In the API listed below, when running
     A list of all tasks for the given stage attempt.
     ?offset=[offset]&length=[len] list tasks in the given range.
     ?sortBy=[runtime|-runtime] sort the tasks.
-    Example: ?offset=10&length=50&sortBy=runtime
+    ?status=[running|success|killed|failed|unknown] list only tasks in the state.
+    Example: ?offset=10&length=50&sortBy=runtime&status=running
diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md
index 70bfefce475a1..07207f62bb9bd 100644
--- a/docs/rdd-programming-guide.md
+++ b/docs/rdd-programming-guide.md
@@ -101,10 +101,10 @@ import org.apache.spark.SparkConf;
-Spark {{site.SPARK_VERSION}} works with Python 2.7+ or Python 3.4+. It can use the standard CPython interpreter, +Spark {{site.SPARK_VERSION}} works with Python 3.6+. It can use the standard CPython interpreter, so C libraries like NumPy can be used. It also works with PyPy 2.3+. -Note that Python 2 support is deprecated as of Spark 3.0.0. +Python 2, 3.4 and 3.5 supports were removed in Spark 3.1.0. Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including it in your setup.py as: @@ -134,8 +134,8 @@ PySpark requires the same minor version of Python in both driver and workers. It you can specify which version of Python you want to use by `PYSPARK_PYTHON`, for example: {% highlight bash %} -$ PYSPARK_PYTHON=python3.4 bin/pyspark -$ PYSPARK_PYTHON=/opt/pypy-2.5/bin/pypy bin/spark-submit examples/src/main/python/pi.py +$ PYSPARK_PYTHON=python3.8 bin/pyspark +$ PYSPARK_PYTHON=/path-to-your-pypy/pypy bin/spark-submit examples/src/main/python/pi.py {% endhighlight %}
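
A small, hedged illustration of the version guidance above: the interpreter used for workers is normally chosen through `PYSPARK_PYTHON` (or the `spark.pyspark.python` property, which takes precedence), and it must match the driver's minor version. The interpreter path below is an assumption for illustration, not something the docs mandate.

```python
import os
import sys

# Choose the worker interpreter before the SparkSession/SparkContext is created;
# the concrete path is an assumption for illustration.
os.environ.setdefault("PYSPARK_PYTHON", "/usr/bin/python3.8")

# The guide above states a Python 3.6+ floor; fail fast on the driver if it is not met.
if sys.version_info < (3, 6):
    raise RuntimeError("PySpark requires Python 3.6 or newer on the driver")

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("interpreter-check-sketch").getOrCreate()
```
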
@@ -276,7 +276,7 @@ $ PYSPARK_DRIVER_PYTHON=jupyter PYSPARK_DRIVER_PYTHON_OPTS=notebook ./bin/pyspar You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`. -After the Jupyter Notebook server is launched, you can create a new "Python 2" notebook from +After the Jupyter Notebook server is launched, you can create a new notebook from the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of your notebook before you start to try Spark from the Jupyter notebook. @@ -447,7 +447,7 @@ Writables are automatically converted: - + diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 6f6ae1c0ff264..578ab90fedfca 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -91,7 +91,7 @@ but Mesos can be run without ZooKeeper using a single master as well. ## Verification To verify that the Mesos cluster is ready for Spark, navigate to the Mesos master webui at port -`:5050` Confirm that all expected machines are present in the slaves tab. +`:5050` Confirm that all expected machines are present in the agents tab. # Connecting Spark to Mesos @@ -99,7 +99,7 @@ To verify that the Mesos cluster is ready for Spark, navigate to the Mesos maste To use Mesos from Spark, you need a Spark binary package available in a place accessible by Mesos, and a Spark driver program configured to connect to Mesos. -Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure +Alternatively, you can also install Spark in the same location in all the Mesos agents, and configure `spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. ## Authenticating to Mesos @@ -138,7 +138,7 @@ Then submit happens as described in Client mode or Cluster mode below ## Uploading Spark Package -When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary +When Mesos runs a task on a Mesos agent for the first time, that agent must have a Spark binary package for running the Spark Mesos executor backend. The Spark package can be hosted at any Hadoop-accessible URI, including HTTP via `http://`, [Amazon Simple Storage Service](http://aws.amazon.com/s3) via `s3n://`, or HDFS via `hdfs://`. @@ -237,7 +237,7 @@ For example: {% endhighlight %} -Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves, as the Spark driver doesn't automatically upload local jars. +Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos agents, as the Spark driver doesn't automatically upload local jars. # Mesos Run Modes @@ -360,7 +360,7 @@ see [Dynamic Resource Allocation](job-scheduling.html#dynamic-resource-allocatio The External Shuffle Service to use is the Mesos Shuffle Service. It provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's -termination. To launch it, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all slave nodes, with `spark.shuffle.service.enabled` set to `true`. +termination. To launch it, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all agent nodes, with `spark.shuffle.service.enabled` set to `true`. This can also be achieved through Marathon, using a unique host constraint, and the following command: `./bin/spark-class org.apache.spark.deploy.mesos.MesosExternalShuffleService`. 
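
For context on the shuffle-service change above, the application side only needs two settings once the external shuffle service is running on each agent node. The sketch below is illustrative: the executor counts are arbitrary example values, and it assumes the Mesos shuffle service has already been started as described.

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("dynamic-allocation-sketch")
    .config("spark.shuffle.service.enabled", "true")       # requires the external shuffle service on each agent
    .config("spark.dynamicAllocation.enabled", "true")
    .config("spark.dynamicAllocation.minExecutors", "1")   # example bounds, not recommendations
    .config("spark.dynamicAllocation.maxExecutors", "8")
    .getOrCreate()
)
```
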
@@ -840,17 +840,17 @@ See the [configuration page](configuration.html) for information on Spark config A few places to look during debugging: - Mesos master on port `:5050` - - Slaves should appear in the slaves tab + - Agents should appear in the agents tab - Spark applications should appear in the frameworks tab - Tasks should appear in the details of a framework - Check the stdout and stderr of the sandbox of failed tasks - Mesos logs - - Master and slave logs are both in `/var/log/mesos` by default + - Master and agent logs are both in `/var/log/mesos` by default And common pitfalls: - Spark assembly not reachable/accessible - - Slaves must be able to download the Spark binary package from the `http://`, `hdfs://` or `s3n://` URL you gave + - Agents must be able to download the Spark binary package from the `http://`, `hdfs://` or `s3n://` URL you gave - Firewall blocking communications - Check for messages about failed connections - Temporarily disable firewalls for debugging and then poke appropriate holes diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f3c479ba26547..4344893fd3584 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -44,7 +44,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./sbin/start-slave.sh + ./sbin/start-worker.sh Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). @@ -90,9 +90,9 @@ Finally, the following configuration options can be passed to the master and wor # Cluster Launch Scripts -To launch a Spark standalone cluster with the launch scripts, you should create a file called conf/slaves in your Spark directory, +To launch a Spark standalone cluster with the launch scripts, you should create a file called conf/workers in your Spark directory, which must contain the hostnames of all the machines where you intend to start Spark workers, one per line. -If conf/slaves does not exist, the launch scripts defaults to a single machine (localhost), which is useful for testing. +If conf/workers does not exist, the launch scripts defaults to a single machine (localhost), which is useful for testing. Note, the master machine accesses each of the worker machines via ssh. By default, ssh is run in parallel and requires password-less (using a private key) access to be setup. If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker. @@ -100,12 +100,12 @@ If you do not have a password-less setup, you can set the environment variable S Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/sbin`: - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. -- `sbin/start-slaves.sh` - Starts a worker instance on each machine specified in the `conf/slaves` file. -- `sbin/start-slave.sh` - Starts a worker instance on the machine the script is executed on. +- `sbin/start-workers.sh` - Starts a worker instance on each machine specified in the `conf/workers` file. +- `sbin/start-worker.sh` - Starts a worker instance on the machine the script is executed on. 
- `sbin/start-all.sh` - Starts both a master and a number of workers as described above. - `sbin/stop-master.sh` - Stops the master that was started via the `sbin/start-master.sh` script. -- `sbin/stop-slave.sh` - Stops all worker instances on the machine the script is executed on. -- `sbin/stop-slaves.sh` - Stops all worker instances on the machines specified in the `conf/slaves` file. +- `sbin/stop-worker.sh` - Stops all worker instances on the machine the script is executed on. +- `sbin/stop-workers.sh` - Stops all worker instances on the machines specified in the `conf/workers` file. - `sbin/stop-all.sh` - Stops both the master and the workers as described above. Note that these scripts must be executed on the machine you want to run the Spark master on, not your local machine. @@ -457,7 +457,7 @@ worker during one single schedule iteration. Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default, you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. -In addition, detailed log output for each job is also written to the work directory of each slave node (`SPARK_HOME/work` by default). You will see two files for each job, `stdout` and `stderr`, with all output it wrote to its console. +In addition, detailed log output for each job is also written to the work directory of each worker node (`SPARK_HOME/work` by default). You will see two files for each job, `stdout` and `stderr`, with all output it wrote to its console. # Running Alongside Hadoop diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 0c84db38afafc..d3138ae319160 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -30,8 +30,6 @@ license: | - In Spark 3.1, `from_unixtime`, `unix_timestamp`,`to_unix_timestamp`, `to_timestamp` and `to_date` will fail if the specified datetime pattern is invalid. In Spark 3.0 or earlier, they result `NULL`. - - In Spark 3.1, casting numeric to timestamp will be forbidden by default. It's strongly recommended to use dedicated functions: TIMESTAMP_SECONDS, TIMESTAMP_MILLIS and TIMESTAMP_MICROS. Or you can set `spark.sql.legacy.allowCastNumericToTimestamp` to true to work around it. See more details in SPARK-31710. - ## Upgrading from Spark SQL 3.0 to 3.0.1 - In Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Since version 3.0.1, the timestamp type inference is disabled by default. Set the JSON option `inferTimestamp` to `true` to enable such type inference. diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index e5ca7e9d10d59..6488ad9cd34c9 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -127,7 +127,7 @@ By default `spark.sql.ansi.enabled` is false. Below is a list of all the keywords in Spark SQL. -|Keyword|Spark SQL
<br/>ANSI Mode|Spark SQL<br/>Default Mode|SQL-2011|
+|Keyword|Spark SQL<br/>ANSI Mode|Spark SQL<br/>
Default Mode|SQL-2016| |-------|----------------------|-------------------------|--------| |ADD|non-reserved|non-reserved|non-reserved| |AFTER|non-reserved|non-reserved|non-reserved| @@ -149,7 +149,7 @@ Below is a list of all the keywords in Spark SQL. |BUCKETS|non-reserved|non-reserved|non-reserved| |BY|non-reserved|non-reserved|reserved| |CACHE|non-reserved|non-reserved|non-reserved| -|CASCADE|non-reserved|non-reserved|reserved| +|CASCADE|non-reserved|non-reserved|non-reserved| |CASE|reserved|non-reserved|reserved| |CAST|reserved|non-reserved|reserved| |CHANGE|non-reserved|non-reserved|non-reserved| @@ -193,7 +193,7 @@ Below is a list of all the keywords in Spark SQL. |DIRECTORY|non-reserved|non-reserved|non-reserved| |DISTINCT|reserved|non-reserved|reserved| |DISTRIBUTE|non-reserved|non-reserved|non-reserved| -|DIV|non-reserved|non-reserved|non-reserved| +|DIV|non-reserved|non-reserved|not a keyword| |DROP|non-reserved|non-reserved|reserved| |ELSE|reserved|non-reserved|reserved| |END|reserved|non-reserved|reserved| @@ -228,7 +228,7 @@ Below is a list of all the keywords in Spark SQL. |GROUPING|non-reserved|non-reserved|reserved| |HAVING|reserved|non-reserved|reserved| |HOUR|reserved|non-reserved|reserved| -|IF|non-reserved|non-reserved|reserved| +|IF|non-reserved|non-reserved|not a keyword| |IGNORE|non-reserved|non-reserved|non-reserved| |IMPORT|non-reserved|non-reserved|non-reserved| |IN|reserved|non-reserved|reserved| @@ -302,12 +302,14 @@ Below is a list of all the keywords in Spark SQL. |PROPERTIES|non-reserved|non-reserved|non-reserved| |PURGE|non-reserved|non-reserved|non-reserved| |QUERY|non-reserved|non-reserved|non-reserved| +|RANGE|non-reserved|non-reserved|reserved| |RECORDREADER|non-reserved|non-reserved|non-reserved| |RECORDWRITER|non-reserved|non-reserved|non-reserved| |RECOVER|non-reserved|non-reserved|non-reserved| |REDUCE|non-reserved|non-reserved|non-reserved| |REFERENCES|reserved|non-reserved|reserved| |REFRESH|non-reserved|non-reserved|non-reserved| +|REGEXP|non-reserved|non-reserved|not a keyword| |RENAME|non-reserved|non-reserved|non-reserved| |REPAIR|non-reserved|non-reserved|non-reserved| |REPLACE|non-reserved|non-reserved|non-reserved| @@ -323,6 +325,7 @@ Below is a list of all the keywords in Spark SQL. |ROW|non-reserved|non-reserved|reserved| |ROWS|non-reserved|non-reserved|reserved| |SCHEMA|non-reserved|non-reserved|non-reserved| +|SCHEMAS|non-reserved|non-reserved|not a keyword| |SECOND|reserved|non-reserved|reserved| |SELECT|reserved|non-reserved|reserved| |SEMI|non-reserved|strict-non-reserved|non-reserved| @@ -348,6 +351,7 @@ Below is a list of all the keywords in Spark SQL. |TABLES|non-reserved|non-reserved|non-reserved| |TABLESAMPLE|non-reserved|non-reserved|reserved| |TBLPROPERTIES|non-reserved|non-reserved|non-reserved| +|TEMP|non-reserved|non-reserved|not a keyword| |TEMPORARY|non-reserved|non-reserved|non-reserved| |TERMINATED|non-reserved|non-reserved|non-reserved| |THEN|reserved|non-reserved|reserved| @@ -360,6 +364,7 @@ Below is a list of all the keywords in Spark SQL. 
|TRIM|non-reserved|non-reserved|non-reserved| |TRUE|non-reserved|non-reserved|reserved| |TRUNCATE|non-reserved|non-reserved|reserved| +|TYPE|non-reserved|non-reserved|non-reserved| |UNARCHIVE|non-reserved|non-reserved|non-reserved| |UNBOUNDED|non-reserved|non-reserved|non-reserved| |UNCACHE|non-reserved|non-reserved|non-reserved| diff --git a/docs/sql-ref-literals.md b/docs/sql-ref-literals.md index b83f7f0a97c24..3dbed846d40b8 100644 --- a/docs/sql-ref-literals.md +++ b/docs/sql-ref-literals.md @@ -219,6 +219,11 @@ double literals: decimal_digits { D | exponent [ D ] } | digit [ ... ] { exponent [ D ] | [ exponent ] D } ``` +float literals: +```sql +decimal_digits { F | exponent [ F ] } | digit [ ... ] { exponent [ F ] | [ exponent ] F } +``` + While decimal_digits is defined as ```sql [ + | - ] { digit [ ... ] . [ digit [ ... ] ] | . digit [ ... ] } @@ -239,6 +244,10 @@ E [ + | - ] digit [ ... ] Case insensitive, indicates `DOUBLE`, which is an 8-byte double-precision floating point number. +* **F** + + Case insensitive, indicates `FLOAT`, which is a 4-byte single-precision floating point number. + * **BD** Case insensitive, indicates `DECIMAL`, with the total number of digits as precision and the number of digits to right of decimal point as scale. diff --git a/docs/sql-ref-syntax-qry-select-like.md b/docs/sql-ref-syntax-qry-select-like.md index feb5eb7b3c80d..6211faa8d529e 100644 --- a/docs/sql-ref-syntax-qry-select-like.md +++ b/docs/sql-ref-syntax-qry-select-like.md @@ -26,7 +26,7 @@ A LIKE predicate is used to search for a specific pattern. ### Syntax ```sql -[ NOT ] { LIKE search_pattern [ ESCAPE esc_char ] | RLIKE regex_pattern } +[ NOT ] { LIKE search_pattern [ ESCAPE esc_char ] | [ RLIKE | REGEXP ] regex_pattern } ``` ### Parameters @@ -44,7 +44,7 @@ A LIKE predicate is used to search for a specific pattern. * **regex_pattern** - Specifies a regular expression search pattern to be searched by the `RLIKE` clause. + Specifies a regular expression search pattern to be searched by the `RLIKE` or `REGEXP` clause. ### Examples @@ -90,6 +90,14 @@ SELECT * FROM person WHERE name RLIKE 'M+'; |200|Mary|null| +---+----+----+ +SELECT * FROM person WHERE name REGEXP 'M+'; ++---+----+----+ +| id|name| age| ++---+----+----+ +|300|Mike| 80| +|200|Mary|null| ++---+----+----+ + SELECT * FROM person WHERE name LIKE '%\_%'; +---+------+---+ | id| name|age| diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index db813c46949c2..c7959d4201151 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -229,7 +229,7 @@ To run the example, - Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created. -- Set up the environment variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_KEY` with your AWS credentials. +- Set up the environment variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your AWS credentials. - In the Spark root directory, run the example as diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index ac4aa9255ae68..56a455a1b8d21 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1822,7 +1822,7 @@ This is shown in the following example.
{% highlight scala %} -object WordBlacklist { +object WordExcludeList { @volatile private var instance: Broadcast[Seq[String]] = null @@ -1830,8 +1830,8 @@ object WordBlacklist { if (instance == null) { synchronized { if (instance == null) { - val wordBlacklist = Seq("a", "b", "c") - instance = sc.broadcast(wordBlacklist) + val wordExcludeList = Seq("a", "b", "c") + instance = sc.broadcast(wordExcludeList) } } } @@ -1847,7 +1847,7 @@ object DroppedWordsCounter { if (instance == null) { synchronized { if (instance == null) { - instance = sc.longAccumulator("WordsInBlacklistCounter") + instance = sc.longAccumulator("DroppedWordsCounter") } } } @@ -1856,13 +1856,13 @@ object DroppedWordsCounter { } wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => - // Get or register the blacklist Broadcast - val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the excludeList Broadcast + val excludeList = WordExcludeList.getInstance(rdd.sparkContext) // Get or register the droppedWordsCounter Accumulator val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) - // Use blacklist to drop words and use droppedWordsCounter to count them + // Use excludeList to drop words and use droppedWordsCounter to count them val counts = rdd.filter { case (word, count) => - if (blacklist.value.contains(word)) { + if (excludeList.value.contains(word)) { droppedWordsCounter.add(count) false } else { @@ -1879,16 +1879,16 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_
{% highlight java %} -class JavaWordBlacklist { +class JavaWordExcludeList { private static volatile Broadcast> instance = null; public static Broadcast> getInstance(JavaSparkContext jsc) { if (instance == null) { - synchronized (JavaWordBlacklist.class) { + synchronized (JavaWordExcludeList.class) { if (instance == null) { - List wordBlacklist = Arrays.asList("a", "b", "c"); - instance = jsc.broadcast(wordBlacklist); + List wordExcludeList = Arrays.asList("a", "b", "c"); + instance = jsc.broadcast(wordExcludeList); } } } @@ -1904,7 +1904,7 @@ class JavaDroppedWordsCounter { if (instance == null) { synchronized (JavaDroppedWordsCounter.class) { if (instance == null) { - instance = jsc.sc().longAccumulator("WordsInBlacklistCounter"); + instance = jsc.sc().longAccumulator("DroppedWordsCounter"); } } } @@ -1913,13 +1913,13 @@ class JavaDroppedWordsCounter { } wordCounts.foreachRDD((rdd, time) -> { - // Get or register the blacklist Broadcast - Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the excludeList Broadcast + Broadcast> excludeList = JavaWordExcludeList.getInstance(new JavaSparkContext(rdd.context())); // Get or register the droppedWordsCounter Accumulator LongAccumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); - // Use blacklist to drop words and use droppedWordsCounter to count them + // Use excludeList to drop words and use droppedWordsCounter to count them String counts = rdd.filter(wordCount -> { - if (blacklist.value().contains(wordCount._1())) { + if (excludeList.value().contains(wordCount._1())) { droppedWordsCounter.add(wordCount._2()); return false; } else { @@ -1935,10 +1935,10 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_
{% highlight python %} -def getWordBlacklist(sparkContext): - if ("wordBlacklist" not in globals()): - globals()["wordBlacklist"] = sparkContext.broadcast(["a", "b", "c"]) - return globals()["wordBlacklist"] +def getWordExcludeList(sparkContext): + if ("wordExcludeList" not in globals()): + globals()["wordExcludeList"] = sparkContext.broadcast(["a", "b", "c"]) + return globals()["wordExcludeList"] def getDroppedWordsCounter(sparkContext): if ("droppedWordsCounter" not in globals()): @@ -1946,14 +1946,14 @@ def getDroppedWordsCounter(sparkContext): return globals()["droppedWordsCounter"] def echo(time, rdd): - # Get or register the blacklist Broadcast - blacklist = getWordBlacklist(rdd.context) + # Get or register the excludeList Broadcast + excludeList = getWordExcludeList(rdd.context) # Get or register the droppedWordsCounter Accumulator droppedWordsCounter = getDroppedWordsCounter(rdd.context) - # Use blacklist to drop words and use droppedWordsCounter to count them + # Use excludeList to drop words and use droppedWordsCounter to count them def filterFunc(wordCount): - if wordCount[0] in blacklist.value: + if wordCount[0] in excludeList.value: droppedWordsCounter.add(wordCount[1]) False else: @@ -2216,7 +2216,7 @@ In specific cases where the amount of data that needs to be retained for the str ### Task Launching Overheads {:.no_toc} If the number of tasks launched per second is high (say, 50 or more per second), then the overhead -of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second +of sending out tasks to the executors may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: * **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java index a0979aa2d24e4..3b5d8e6d555eb 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -23,7 +23,7 @@ import java.util.Arrays; import java.util.List; -import scala.collection.mutable.WrappedArray; +import scala.collection.mutable.Seq; import org.apache.spark.ml.feature.RegexTokenizer; import org.apache.spark.ml.feature.Tokenizer; @@ -69,7 +69,7 @@ public static void main(String[] args) { .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); spark.udf().register( - "countTokens", (WrappedArray words) -> words.size(), DataTypes.IntegerType); + "countTokens", (Seq words) -> words.size(), DataTypes.IntegerType); Dataset tokenized = tokenizer.transform(sentenceDataFrame); tokenized.select("sentence", "words") diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index 45a876decff8b..c01a62b078f7a 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -41,16 +41,16 @@ /** * Use this singleton to get or register a Broadcast variable. 
*/ -class JavaWordBlacklist { +class JavaWordExcludeList { private static volatile Broadcast> instance = null; public static Broadcast> getInstance(JavaSparkContext jsc) { if (instance == null) { - synchronized (JavaWordBlacklist.class) { + synchronized (JavaWordExcludeList.class) { if (instance == null) { - List wordBlacklist = Arrays.asList("a", "b", "c"); - instance = jsc.broadcast(wordBlacklist); + List wordExcludeList = Arrays.asList("a", "b", "c"); + instance = jsc.broadcast(wordExcludeList); } } } @@ -69,7 +69,7 @@ public static LongAccumulator getInstance(JavaSparkContext jsc) { if (instance == null) { synchronized (JavaDroppedWordsCounter.class) { if (instance == null) { - instance = jsc.sc().longAccumulator("WordsInBlacklistCounter"); + instance = jsc.sc().longAccumulator("DroppedWordsCounter"); } } } @@ -133,15 +133,15 @@ private static JavaStreamingContext createContext(String ip, .reduceByKey((i1, i2) -> i1 + i2); wordCounts.foreachRDD((rdd, time) -> { - // Get or register the blacklist Broadcast - Broadcast> blacklist = - JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the excludeList Broadcast + Broadcast> excludeList = + JavaWordExcludeList.getInstance(new JavaSparkContext(rdd.context())); // Get or register the droppedWordsCounter Accumulator LongAccumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); - // Use blacklist to drop words and use droppedWordsCounter to count them + // Use excludeList to drop words and use droppedWordsCounter to count them String counts = rdd.filter(wordCount -> { - if (blacklist.value().contains(wordCount._1())) { + if (excludeList.value().contains(wordCount._1())) { droppedWordsCounter.add(wordCount._2()); return false; } else { diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 6d3241876ad51..511634fd8f6c2 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -21,8 +21,6 @@ This example requires numpy (http://www.numpy.org/) """ -from __future__ import print_function - import sys import numpy as np diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index a18722c687f8b..49ab37e7b3286 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -43,8 +43,6 @@ {u'favorite_color': None, u'name': u'Alyssa'} {u'favorite_color': u'red', u'name': u'Ben'} """ -from __future__ import print_function - import sys from functools import reduce diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index a42d711fc505f..022378619c97f 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -22,8 +22,6 @@ This example requires NumPy (http://www.numpy.org/). """ -from __future__ import print_function - import sys import numpy as np diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index bcc4e0f4e8eae..4b83740152ca4 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -22,8 +22,6 @@ In practice, one may prefer to use the LogisticRegression algorithm in ML, as shown in examples/src/main/python/ml/logistic_regression_with_elastic_net.py. 
""" -from __future__ import print_function - import sys import numpy as np diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py index 0a71f76418ea6..2040a7876c7fa 100644 --- a/examples/src/main/python/ml/aft_survival_regression.py +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.regression import AFTSurvivalRegression from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 8b7ec9c439f9f..b39263978402b 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -15,12 +15,6 @@ # limitations under the License. # -from __future__ import print_function - -import sys -if sys.version >= '3': - long = int - from pyspark.sql import SparkSession # $example on$ @@ -39,7 +33,7 @@ lines = spark.read.text("data/mllib/als/sample_movielens_ratings.txt").rdd parts = lines.map(lambda row: row.value.split("::")) ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]), - rating=float(p[2]), timestamp=long(p[3]))) + rating=float(p[2]), timestamp=int(p[3]))) ratings = spark.createDataFrame(ratingsRDD) (training, test) = ratings.randomSplit([0.8, 0.2]) diff --git a/examples/src/main/python/ml/anova_selector_example.py b/examples/src/main/python/ml/anova_selector_example.py index f8458f5d6e487..da80fa62316d7 100644 --- a/examples/src/main/python/ml/anova_selector_example.py +++ b/examples/src/main/python/ml/anova_selector_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/anova_selector_example.py """ -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import ANOVASelector diff --git a/examples/src/main/python/ml/anova_test_example.py b/examples/src/main/python/ml/anova_test_example.py index 4119441cdeab6..451e078f60e56 100644 --- a/examples/src/main/python/ml/anova_test_example.py +++ b/examples/src/main/python/ml/anova_test_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/anova_test_example.py """ -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py index 669bb2aeabecd..5d5ae4122e1d4 100644 --- a/examples/src/main/python/ml/binarizer_example.py +++ b/examples/src/main/python/ml/binarizer_example.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import Binarizer diff --git a/examples/src/main/python/ml/bisecting_k_means_example.py b/examples/src/main/python/ml/bisecting_k_means_example.py index 82adb338b5d91..513f80a09ef05 100644 --- a/examples/src/main/python/ml/bisecting_k_means_example.py +++ b/examples/src/main/python/ml/bisecting_k_means_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.clustering import BisectingKMeans from pyspark.ml.evaluation import ClusteringEvaluator diff --git a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py index 610176ea596ca..f5836091f35ba 100644 --- a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py +++ b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.feature import BucketedRandomProjectionLSH from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py index 742f35093b9d2..5de67f7126b5e 100644 --- a/examples/src/main/python/ml/bucketizer_example.py +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import Bucketizer diff --git a/examples/src/main/python/ml/chi_square_test_example.py b/examples/src/main/python/ml/chi_square_test_example.py index 2af7e683cdb72..bf15a03d9cb4c 100644 --- a/examples/src/main/python/ml/chi_square_test_example.py +++ b/examples/src/main/python/ml/chi_square_test_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/chi_square_test_example.py """ -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/chisq_selector_example.py b/examples/src/main/python/ml/chisq_selector_example.py index 028a9ea9d67b1..c83a8c1bc7b27 100644 --- a/examples/src/main/python/ml/chisq_selector_example.py +++ b/examples/src/main/python/ml/chisq_selector_example.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import ChiSqSelector diff --git a/examples/src/main/python/ml/correlation_example.py b/examples/src/main/python/ml/correlation_example.py index 1f4e402ac1a51..9006d541491fb 100644 --- a/examples/src/main/python/ml/correlation_example.py +++ b/examples/src/main/python/ml/correlation_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/correlation_example.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.linalg import Vectors from pyspark.ml.stat import Correlation diff --git a/examples/src/main/python/ml/count_vectorizer_example.py b/examples/src/main/python/ml/count_vectorizer_example.py index f2e41db77d898..b3ddfb128c3d0 100644 --- a/examples/src/main/python/ml/count_vectorizer_example.py +++ b/examples/src/main/python/ml/count_vectorizer_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import CountVectorizer diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index 6256d11504afb..0ad0865486959 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -22,8 +22,6 @@ bin/spark-submit examples/src/main/python/ml/cross_validator.py """ -from __future__ import print_function - # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index cabc3de68f2f4..d2bf93744113b 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -19,8 +19,6 @@ An example of how to use DataFrame for ML. Run with:: bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ -from __future__ import print_function - import os import sys import tempfile diff --git a/examples/src/main/python/ml/dct_example.py b/examples/src/main/python/ml/dct_example.py index c0457f8d0f43b..37da4f5e8f1cb 100644 --- a/examples/src/main/python/ml/dct_example.py +++ b/examples/src/main/python/ml/dct_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import DCT from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/decision_tree_classification_example.py b/examples/src/main/python/ml/decision_tree_classification_example.py index d6e2977de0082..eb7177b845357 100644 --- a/examples/src/main/python/ml/decision_tree_classification_example.py +++ b/examples/src/main/python/ml/decision_tree_classification_example.py @@ -18,8 +18,6 @@ """ Decision Tree Classification Example. """ -from __future__ import print_function - # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import DecisionTreeClassifier diff --git a/examples/src/main/python/ml/decision_tree_regression_example.py b/examples/src/main/python/ml/decision_tree_regression_example.py index 58d7ad921d8e0..1ed1636a3d962 100644 --- a/examples/src/main/python/ml/decision_tree_regression_example.py +++ b/examples/src/main/python/ml/decision_tree_regression_example.py @@ -18,8 +18,6 @@ """ Decision Tree Regression Example. 
""" -from __future__ import print_function - # $example on$ from pyspark.ml import Pipeline from pyspark.ml.regression import DecisionTreeRegressor diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py index 590053998bccc..71eec8d432998 100644 --- a/examples/src/main/python/ml/elementwise_product_example.py +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import ElementwiseProduct from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/estimator_transformer_param_example.py b/examples/src/main/python/ml/estimator_transformer_param_example.py index eb21051435393..1dcca6c201119 100644 --- a/examples/src/main/python/ml/estimator_transformer_param_example.py +++ b/examples/src/main/python/ml/estimator_transformer_param_example.py @@ -18,8 +18,6 @@ """ Estimator Transformer Param Example. """ -from __future__ import print_function - # $example on$ from pyspark.ml.linalg import Vectors from pyspark.ml.classification import LogisticRegression diff --git a/examples/src/main/python/ml/feature_hasher_example.py b/examples/src/main/python/ml/feature_hasher_example.py index 6cf9ecc396400..4fe573d19dfbc 100644 --- a/examples/src/main/python/ml/feature_hasher_example.py +++ b/examples/src/main/python/ml/feature_hasher_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import FeatureHasher diff --git a/examples/src/main/python/ml/fm_classifier_example.py b/examples/src/main/python/ml/fm_classifier_example.py index 6e7c2ccf021ed..b47bdc5275beb 100644 --- a/examples/src/main/python/ml/fm_classifier_example.py +++ b/examples/src/main/python/ml/fm_classifier_example.py @@ -18,8 +18,6 @@ """ FMClassifier Example. """ -from __future__ import print_function - # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import FMClassifier diff --git a/examples/src/main/python/ml/fm_regressor_example.py b/examples/src/main/python/ml/fm_regressor_example.py index afd76396800b7..5c8133996ae83 100644 --- a/examples/src/main/python/ml/fm_regressor_example.py +++ b/examples/src/main/python/ml/fm_regressor_example.py @@ -18,8 +18,6 @@ """ FMRegressor Example. 
""" -from __future__ import print_function - # $example on$ from pyspark.ml import Pipeline from pyspark.ml.regression import FMRegressor diff --git a/examples/src/main/python/ml/fvalue_selector_example.py b/examples/src/main/python/ml/fvalue_selector_example.py index 3158953a5dfc4..f164af47eb309 100644 --- a/examples/src/main/python/ml/fvalue_selector_example.py +++ b/examples/src/main/python/ml/fvalue_selector_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/fvalue_selector_example.py """ -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import FValueSelector diff --git a/examples/src/main/python/ml/fvalue_test_example.py b/examples/src/main/python/ml/fvalue_test_example.py index 410b39e4493f8..dfa8073e5afc9 100644 --- a/examples/src/main/python/ml/fvalue_test_example.py +++ b/examples/src/main/python/ml/fvalue_test_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/fvalue_test_example.py """ -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/gaussian_mixture_example.py b/examples/src/main/python/ml/gaussian_mixture_example.py index 4938a904189f9..1441faa792983 100644 --- a/examples/src/main/python/ml/gaussian_mixture_example.py +++ b/examples/src/main/python/ml/gaussian_mixture_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.clustering import GaussianMixture # $example off$ diff --git a/examples/src/main/python/ml/generalized_linear_regression_example.py b/examples/src/main/python/ml/generalized_linear_regression_example.py index a52f4650c1c6f..06a8a5a2e9428 100644 --- a/examples/src/main/python/ml/generalized_linear_regression_example.py +++ b/examples/src/main/python/ml/generalized_linear_regression_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py """ -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.regression import GeneralizedLinearRegression diff --git a/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py index c2042fd7b7b07..a7efa2170a069 100644 --- a/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py +++ b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py @@ -18,8 +18,6 @@ """ Gradient Boosted Tree Classifier Example. """ -from __future__ import print_function - # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import GBTClassifier diff --git a/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py index cc96c973e4b23..5e09b96c1ea3a 100644 --- a/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py +++ b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py @@ -18,8 +18,6 @@ """ Gradient Boosted Tree Regressor Example. 
""" -from __future__ import print_function - # $example on$ from pyspark.ml import Pipeline from pyspark.ml.regression import GBTRegressor diff --git a/examples/src/main/python/ml/index_to_string_example.py b/examples/src/main/python/ml/index_to_string_example.py index 33d104e8e3f41..98bdb89ce3039 100644 --- a/examples/src/main/python/ml/index_to_string_example.py +++ b/examples/src/main/python/ml/index_to_string_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import IndexToString, StringIndexer # $example off$ diff --git a/examples/src/main/python/ml/interaction_example.py b/examples/src/main/python/ml/interaction_example.py index 4b632271916f5..ac365179b0c20 100644 --- a/examples/src/main/python/ml/interaction_example.py +++ b/examples/src/main/python/ml/interaction_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import Interaction, VectorAssembler # $example off$ diff --git a/examples/src/main/python/ml/isotonic_regression_example.py b/examples/src/main/python/ml/isotonic_regression_example.py index 89cba9dfc7e8f..d7b893894fc71 100644 --- a/examples/src/main/python/ml/isotonic_regression_example.py +++ b/examples/src/main/python/ml/isotonic_regression_example.py @@ -21,8 +21,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.regression import IsotonicRegression # $example off$ diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py index 80a878af679f4..47223fd953d17 100644 --- a/examples/src/main/python/ml/kmeans_example.py +++ b/examples/src/main/python/ml/kmeans_example.py @@ -22,8 +22,6 @@ This example requires NumPy (http://www.numpy.org/). """ -from __future__ import print_function - # $example on$ from pyspark.ml.clustering import KMeans from pyspark.ml.evaluation import ClusteringEvaluator diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py index 97d1a042d1479..a47dfa383c895 100644 --- a/examples/src/main/python/ml/lda_example.py +++ b/examples/src/main/python/ml/lda_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/lda_example.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.clustering import LDA # $example off$ diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py index 6639e9160ab71..864fc76cff132 100644 --- a/examples/src/main/python/ml/linear_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.regression import LinearRegression # $example off$ diff --git a/examples/src/main/python/ml/linearsvc.py b/examples/src/main/python/ml/linearsvc.py index 9b79abbf96f88..61d726cf3f1ae 100644 --- a/examples/src/main/python/ml/linearsvc.py +++ b/examples/src/main/python/ml/linearsvc.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - # $example on$ from pyspark.ml.classification import LinearSVC # $example off$ diff --git a/examples/src/main/python/ml/logistic_regression_summary_example.py b/examples/src/main/python/ml/logistic_regression_summary_example.py index 2274ff707b2a3..6d045108da0aa 100644 --- a/examples/src/main/python/ml/logistic_regression_summary_example.py +++ b/examples/src/main/python/ml/logistic_regression_summary_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.classification import LogisticRegression # $example off$ diff --git a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py index d095fbd373408..916fdade27623 100644 --- a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.classification import LogisticRegression # $example off$ diff --git a/examples/src/main/python/ml/max_abs_scaler_example.py b/examples/src/main/python/ml/max_abs_scaler_example.py index 45eda3cdadde3..d7ff3561ce429 100644 --- a/examples/src/main/python/ml/max_abs_scaler_example.py +++ b/examples/src/main/python/ml/max_abs_scaler_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import MaxAbsScaler from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/min_hash_lsh_example.py b/examples/src/main/python/ml/min_hash_lsh_example.py index 93136e6ae3cae..683f97a055ede 100644 --- a/examples/src/main/python/ml/min_hash_lsh_example.py +++ b/examples/src/main/python/ml/min_hash_lsh_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.feature import MinHashLSH from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/min_max_scaler_example.py b/examples/src/main/python/ml/min_max_scaler_example.py index b5f272e59bc30..cd74243699894 100644 --- a/examples/src/main/python/ml/min_max_scaler_example.py +++ b/examples/src/main/python/ml/min_max_scaler_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import MinMaxScaler from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py index bec9860c79a2d..3bb4a72864101 100644 --- a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - # $example on$ from pyspark.ml.classification import LogisticRegression # $example off$ diff --git a/examples/src/main/python/ml/multilayer_perceptron_classification.py b/examples/src/main/python/ml/multilayer_perceptron_classification.py index 88fc69f753953..74f532193573d 100644 --- a/examples/src/main/python/ml/multilayer_perceptron_classification.py +++ b/examples/src/main/python/ml/multilayer_perceptron_classification.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.classification import MultilayerPerceptronClassifier from pyspark.ml.evaluation import MulticlassClassificationEvaluator diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py index 31676e076a11b..8c8031b939458 100644 --- a/examples/src/main/python/ml/n_gram_example.py +++ b/examples/src/main/python/ml/n_gram_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import NGram # $example off$ diff --git a/examples/src/main/python/ml/naive_bayes_example.py b/examples/src/main/python/ml/naive_bayes_example.py index 7290ab81cd0ec..8d1777c6f9e39 100644 --- a/examples/src/main/python/ml/naive_bayes_example.py +++ b/examples/src/main/python/ml/naive_bayes_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.classification import NaiveBayes from pyspark.ml.evaluation import MulticlassClassificationEvaluator diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py index 510bd825fd286..2aa012961a2ee 100644 --- a/examples/src/main/python/ml/normalizer_example.py +++ b/examples/src/main/python/ml/normalizer_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import Normalizer from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/one_vs_rest_example.py b/examples/src/main/python/ml/one_vs_rest_example.py index 956e94ae4ab62..4cae1a99808e8 100644 --- a/examples/src/main/python/ml/one_vs_rest_example.py +++ b/examples/src/main/python/ml/one_vs_rest_example.py @@ -21,8 +21,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py """ -from __future__ import print_function - # $example on$ from pyspark.ml.classification import LogisticRegression, OneVsRest from pyspark.ml.evaluation import MulticlassClassificationEvaluator diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py index 73775b79e36cb..6deb84ed785ca 100644 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import OneHotEncoder # $example off$ diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py index 38746aced096a..03fb709c8e91d 100644 --- a/examples/src/main/python/ml/pca_example.py +++ b/examples/src/main/python/ml/pca_example.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - # $example on$ from pyspark.ml.feature import PCA from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py index 40bcb7b13a3de..75f436e768dc5 100644 --- a/examples/src/main/python/ml/polynomial_expansion_example.py +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import PolynomialExpansion from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/quantile_discretizer_example.py b/examples/src/main/python/ml/quantile_discretizer_example.py index 0fc1d1949a77d..82be3936d2598 100644 --- a/examples/src/main/python/ml/quantile_discretizer_example.py +++ b/examples/src/main/python/ml/quantile_discretizer_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import QuantileDiscretizer # $example off$ diff --git a/examples/src/main/python/ml/random_forest_classifier_example.py b/examples/src/main/python/ml/random_forest_classifier_example.py index 4eaa94dd7f489..8983d1f2e979b 100644 --- a/examples/src/main/python/ml/random_forest_classifier_example.py +++ b/examples/src/main/python/ml/random_forest_classifier_example.py @@ -18,8 +18,6 @@ """ Random Forest Classifier Example. """ -from __future__ import print_function - # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import RandomForestClassifier diff --git a/examples/src/main/python/ml/random_forest_regressor_example.py b/examples/src/main/python/ml/random_forest_regressor_example.py index a34edff2ecaa2..b9306ddf2f82c 100644 --- a/examples/src/main/python/ml/random_forest_regressor_example.py +++ b/examples/src/main/python/ml/random_forest_regressor_example.py @@ -18,8 +18,6 @@ """ Random Forest Regressor Example. """ -from __future__ import print_function - # $example on$ from pyspark.ml import Pipeline from pyspark.ml.regression import RandomForestRegressor diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py index 6629239db29ec..25bb6dac56e81 100644 --- a/examples/src/main/python/ml/rformula_example.py +++ b/examples/src/main/python/ml/rformula_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import RFormula # $example off$ diff --git a/examples/src/main/python/ml/robust_scaler_example.py b/examples/src/main/python/ml/robust_scaler_example.py index 435e9ccb806c6..9f7c6d6507c78 100644 --- a/examples/src/main/python/ml/robust_scaler_example.py +++ b/examples/src/main/python/ml/robust_scaler_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import RobustScaler # $example off$ diff --git a/examples/src/main/python/ml/sql_transformer.py b/examples/src/main/python/ml/sql_transformer.py index 0bf8f35720c95..c8ac5c46aa5e9 100644 --- a/examples/src/main/python/ml/sql_transformer.py +++ b/examples/src/main/python/ml/sql_transformer.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - # $example on$ from pyspark.ml.feature import SQLTransformer # $example off$ diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py index c0027480e69b3..9021c10075d81 100644 --- a/examples/src/main/python/ml/standard_scaler_example.py +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import StandardScaler # $example off$ diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py index 3b8e7855e3e79..832a7c7d0ad88 100644 --- a/examples/src/main/python/ml/stopwords_remover_example.py +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import StopWordsRemover # $example off$ diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py index 2255bfb9c1a60..f2ac63eabd71c 100644 --- a/examples/src/main/python/ml/string_indexer_example.py +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import StringIndexer # $example off$ diff --git a/examples/src/main/python/ml/summarizer_example.py b/examples/src/main/python/ml/summarizer_example.py index 8835f189a1ad4..4982746450132 100644 --- a/examples/src/main/python/ml/summarizer_example.py +++ b/examples/src/main/python/ml/summarizer_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/summarizer_example.py """ -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.stat import Summarizer diff --git a/examples/src/main/python/ml/tf_idf_example.py b/examples/src/main/python/ml/tf_idf_example.py index d43244fa68e97..b4bb0dfa3183c 100644 --- a/examples/src/main/python/ml/tf_idf_example.py +++ b/examples/src/main/python/ml/tf_idf_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import HashingTF, IDF, Tokenizer # $example off$ diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py index 5c65c5c9f8260..c6b5fac227315 100644 --- a/examples/src/main/python/ml/tokenizer_example.py +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - # $example on$ from pyspark.ml.feature import Tokenizer, RegexTokenizer from pyspark.sql.functions import col, udf diff --git a/examples/src/main/python/ml/variance_threshold_selector_example.py b/examples/src/main/python/ml/variance_threshold_selector_example.py index b7edb86653530..0a996e0e28264 100644 --- a/examples/src/main/python/ml/variance_threshold_selector_example.py +++ b/examples/src/main/python/ml/variance_threshold_selector_example.py @@ -20,8 +20,6 @@ Run with: bin/spark-submit examples/src/main/python/ml/variance_threshold_selector_example.py """ -from __future__ import print_function - from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import VarianceThresholdSelector diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py index 98de1d5ea7dac..0ce31cf0eabc9 100644 --- a/examples/src/main/python/ml/vector_assembler_example.py +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.linalg import Vectors from pyspark.ml.feature import VectorAssembler diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py index 5c2956077d6ce..51a4191606fb8 100644 --- a/examples/src/main/python/ml/vector_indexer_example.py +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import VectorIndexer # $example off$ diff --git a/examples/src/main/python/ml/vector_size_hint_example.py b/examples/src/main/python/ml/vector_size_hint_example.py index fb77dacec629d..355d85aee8729 100644 --- a/examples/src/main/python/ml/vector_size_hint_example.py +++ b/examples/src/main/python/ml/vector_size_hint_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.linalg import Vectors from pyspark.ml.feature import (VectorSizeHint, VectorAssembler) diff --git a/examples/src/main/python/ml/vector_slicer_example.py b/examples/src/main/python/ml/vector_slicer_example.py index 68c8cfe27e375..86e089d152c5a 100644 --- a/examples/src/main/python/ml/vector_slicer_example.py +++ b/examples/src/main/python/ml/vector_slicer_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import VectorSlicer from pyspark.ml.linalg import Vectors diff --git a/examples/src/main/python/ml/word2vec_example.py b/examples/src/main/python/ml/word2vec_example.py index 77f8951df0883..0eabeda3dce4b 100644 --- a/examples/src/main/python/ml/word2vec_example.py +++ b/examples/src/main/python/ml/word2vec_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from pyspark.ml.feature import Word2Vec # $example off$ diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py index d14ce7982e24f..741746e6e35ae 100644 --- a/examples/src/main/python/mllib/binary_classification_metrics_example.py +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -17,7 +17,6 @@ """ Binary Classification Metrics Example. 
""" -from __future__ import print_function from pyspark import SparkContext # $example on$ from pyspark.mllib.classification import LogisticRegressionWithLBFGS diff --git a/examples/src/main/python/mllib/bisecting_k_means_example.py b/examples/src/main/python/mllib/bisecting_k_means_example.py index 36e36fc6897f3..d7b6ad9d424a6 100644 --- a/examples/src/main/python/mllib/bisecting_k_means_example.py +++ b/examples/src/main/python/mllib/bisecting_k_means_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from numpy import array # $example off$ diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 089504fa7064b..27d07b22a5645 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -18,8 +18,6 @@ """ Correlations using MLlib. """ -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/mllib/correlations_example.py b/examples/src/main/python/mllib/correlations_example.py index 66d18f6e5df17..bb71b968687cb 100644 --- a/examples/src/main/python/mllib/correlations_example.py +++ b/examples/src/main/python/mllib/correlations_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - import numpy as np from pyspark import SparkContext diff --git a/examples/src/main/python/mllib/decision_tree_classification_example.py b/examples/src/main/python/mllib/decision_tree_classification_example.py index 7eecf500584ad..009e393226c01 100644 --- a/examples/src/main/python/mllib/decision_tree_classification_example.py +++ b/examples/src/main/python/mllib/decision_tree_classification_example.py @@ -18,8 +18,6 @@ """ Decision Tree Classification Example. """ -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.tree import DecisionTree, DecisionTreeModel diff --git a/examples/src/main/python/mllib/decision_tree_regression_example.py b/examples/src/main/python/mllib/decision_tree_regression_example.py index acf9e25fdf31c..71dfbf0790175 100644 --- a/examples/src/main/python/mllib/decision_tree_regression_example.py +++ b/examples/src/main/python/mllib/decision_tree_regression_example.py @@ -18,8 +18,6 @@ """ Decision Tree Regression Example. """ -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.tree import DecisionTree, DecisionTreeModel diff --git a/examples/src/main/python/mllib/elementwise_product_example.py b/examples/src/main/python/mllib/elementwise_product_example.py index 8ae9afb1dc477..15e6a43f736cf 100644 --- a/examples/src/main/python/mllib/elementwise_product_example.py +++ b/examples/src/main/python/mllib/elementwise_product_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.feature import ElementwiseProduct diff --git a/examples/src/main/python/mllib/gaussian_mixture_example.py b/examples/src/main/python/mllib/gaussian_mixture_example.py index a60e799d62eb1..3b19478f457ec 100644 --- a/examples/src/main/python/mllib/gaussian_mixture_example.py +++ b/examples/src/main/python/mllib/gaussian_mixture_example.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - # $example on$ from numpy import array # $example off$ diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py index 6b46e27ddaaa8..96ce6b6f6ab25 100644 --- a/examples/src/main/python/mllib/gaussian_mixture_model.py +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -18,11 +18,6 @@ """ A Gaussian Mixture Model clustering program using MLlib. """ -from __future__ import print_function - -import sys -if sys.version >= '3': - long = int import random import argparse @@ -53,7 +48,7 @@ def parseVector(line): parser.add_argument('--convergenceTol', default=1e-3, type=float, help='convergence threshold') parser.add_argument('--maxIterations', default=100, type=int, help='Number of iterations') parser.add_argument('--seed', default=random.getrandbits(19), - type=long, help='Random seed') + type=int, help='Random seed') args = parser.parse_args() conf = SparkConf().setAppName("GMM") diff --git a/examples/src/main/python/mllib/gradient_boosting_classification_example.py b/examples/src/main/python/mllib/gradient_boosting_classification_example.py index 65a03572be9b5..eb12f206196fe 100644 --- a/examples/src/main/python/mllib/gradient_boosting_classification_example.py +++ b/examples/src/main/python/mllib/gradient_boosting_classification_example.py @@ -18,8 +18,6 @@ """ Gradient Boosted Trees Classification Example. """ -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel diff --git a/examples/src/main/python/mllib/gradient_boosting_regression_example.py b/examples/src/main/python/mllib/gradient_boosting_regression_example.py index 877f8ab461ccd..eb59a992df539 100644 --- a/examples/src/main/python/mllib/gradient_boosting_regression_example.py +++ b/examples/src/main/python/mllib/gradient_boosting_regression_example.py @@ -18,8 +18,6 @@ """ Gradient Boosted Trees Regression Example. """ -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel diff --git a/examples/src/main/python/mllib/hypothesis_testing_example.py b/examples/src/main/python/mllib/hypothesis_testing_example.py index 21a5584fd6e06..321be8b76f1b9 100644 --- a/examples/src/main/python/mllib/hypothesis_testing_example.py +++ b/examples/src/main/python/mllib/hypothesis_testing_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.linalg import Matrices, Vectors diff --git a/examples/src/main/python/mllib/hypothesis_testing_kolmogorov_smirnov_test_example.py b/examples/src/main/python/mllib/hypothesis_testing_kolmogorov_smirnov_test_example.py index ef380dee79d3d..12a186900e358 100644 --- a/examples/src/main/python/mllib/hypothesis_testing_kolmogorov_smirnov_test_example.py +++ b/examples/src/main/python/mllib/hypothesis_testing_kolmogorov_smirnov_test_example.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.stat import Statistics diff --git a/examples/src/main/python/mllib/isotonic_regression_example.py b/examples/src/main/python/mllib/isotonic_regression_example.py index f5322d79c45ba..a5a0cfeae9d75 100644 --- a/examples/src/main/python/mllib/isotonic_regression_example.py +++ b/examples/src/main/python/mllib/isotonic_regression_example.py @@ -18,8 +18,6 @@ """ Isotonic Regression Example. """ -from __future__ import print_function - from pyspark import SparkContext # $example on$ import math diff --git a/examples/src/main/python/mllib/k_means_example.py b/examples/src/main/python/mllib/k_means_example.py index d6058f45020c4..ead1e56de55c6 100644 --- a/examples/src/main/python/mllib/k_means_example.py +++ b/examples/src/main/python/mllib/k_means_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - # $example on$ from numpy import array from math import sqrt diff --git a/examples/src/main/python/mllib/kernel_density_estimation_example.py b/examples/src/main/python/mllib/kernel_density_estimation_example.py index 3e8f7241a4a1e..22d191716057c 100644 --- a/examples/src/main/python/mllib/kernel_density_estimation_example.py +++ b/examples/src/main/python/mllib/kernel_density_estimation_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.stat import KernelDensity diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 1bdb3e9b4a2af..2560384b6a0e2 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -20,8 +20,6 @@ This example requires NumPy (http://www.numpy.org/). """ -from __future__ import print_function - import sys import numpy as np diff --git a/examples/src/main/python/mllib/latent_dirichlet_allocation_example.py b/examples/src/main/python/mllib/latent_dirichlet_allocation_example.py index 2a1bef5f207b7..f82a28aadc5a3 100644 --- a/examples/src/main/python/mllib/latent_dirichlet_allocation_example.py +++ b/examples/src/main/python/mllib/latent_dirichlet_allocation_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.clustering import LDA, LDAModel diff --git a/examples/src/main/python/mllib/linear_regression_with_sgd_example.py b/examples/src/main/python/mllib/linear_regression_with_sgd_example.py index 6744463d40ef1..cb67396332312 100644 --- a/examples/src/main/python/mllib/linear_regression_with_sgd_example.py +++ b/examples/src/main/python/mllib/linear_regression_with_sgd_example.py @@ -18,8 +18,6 @@ """ Linear Regression With SGD Example. """ -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 87efe17375226..7b90615a53424 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -20,8 +20,6 @@ This example requires NumPy (http://www.numpy.org/). 
""" -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py b/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py index c9b768b3147d2..ac5ab1d1b5d91 100644 --- a/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py +++ b/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py @@ -18,8 +18,6 @@ """ Logistic Regression With LBFGS Example. """ -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py index a29fcccac5bfc..74d18233d533a 100644 --- a/examples/src/main/python/mllib/naive_bayes_example.py +++ b/examples/src/main/python/mllib/naive_bayes_example.py @@ -22,8 +22,6 @@ `spark-submit --master local[4] examples/src/main/python/mllib/naive_bayes_example.py` """ -from __future__ import print_function - import shutil from pyspark import SparkContext diff --git a/examples/src/main/python/mllib/normalizer_example.py b/examples/src/main/python/mllib/normalizer_example.py index a4e028ca9af8b..d46110d9a0300 100644 --- a/examples/src/main/python/mllib/normalizer_example.py +++ b/examples/src/main/python/mllib/normalizer_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.feature import Normalizer diff --git a/examples/src/main/python/mllib/power_iteration_clustering_example.py b/examples/src/main/python/mllib/power_iteration_clustering_example.py index ca19c0ccb60c8..60eedef5fab30 100644 --- a/examples/src/main/python/mllib/power_iteration_clustering_example.py +++ b/examples/src/main/python/mllib/power_iteration_clustering_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.clustering import PowerIterationClustering, PowerIterationClusteringModel diff --git a/examples/src/main/python/mllib/random_forest_classification_example.py b/examples/src/main/python/mllib/random_forest_classification_example.py index 5ac67520daee0..a929c10d5a573 100644 --- a/examples/src/main/python/mllib/random_forest_classification_example.py +++ b/examples/src/main/python/mllib/random_forest_classification_example.py @@ -18,8 +18,6 @@ """ Random Forest Classification Example. """ -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.tree import RandomForest, RandomForestModel diff --git a/examples/src/main/python/mllib/random_forest_regression_example.py b/examples/src/main/python/mllib/random_forest_regression_example.py index 7e986a0d307f0..4e05937768211 100644 --- a/examples/src/main/python/mllib/random_forest_regression_example.py +++ b/examples/src/main/python/mllib/random_forest_regression_example.py @@ -18,8 +18,6 @@ """ Random Forest Regression Example. 
""" -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.tree import RandomForest, RandomForestModel diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 9a429b5f8abdf..49afcfe9391ab 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -18,8 +18,6 @@ """ Randomly generated RDDs. """ -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/mllib/recommendation_example.py b/examples/src/main/python/mllib/recommendation_example.py index 00e683c3ae938..719f3f904b246 100644 --- a/examples/src/main/python/mllib/recommendation_example.py +++ b/examples/src/main/python/mllib/recommendation_example.py @@ -18,8 +18,6 @@ """ Collaborative Filtering Classification Example. """ -from __future__ import print_function - from pyspark import SparkContext # $example on$ diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index 00e7cf4bbcdbf..9095c2b2d70d6 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -18,8 +18,6 @@ """ Randomly sampled RDDs. """ -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/mllib/standard_scaler_example.py b/examples/src/main/python/mllib/standard_scaler_example.py index 11ed34427dfe2..c8fd64dfbbf4a 100644 --- a/examples/src/main/python/mllib/standard_scaler_example.py +++ b/examples/src/main/python/mllib/standard_scaler_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.feature import StandardScaler diff --git a/examples/src/main/python/mllib/stratified_sampling_example.py b/examples/src/main/python/mllib/stratified_sampling_example.py index a13f8f08dd68b..2d29f74a19c1a 100644 --- a/examples/src/main/python/mllib/stratified_sampling_example.py +++ b/examples/src/main/python/mllib/stratified_sampling_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext if __name__ == "__main__": diff --git a/examples/src/main/python/mllib/streaming_k_means_example.py b/examples/src/main/python/mllib/streaming_k_means_example.py index e82509ad3ffb6..4904a9ebcf544 100644 --- a/examples/src/main/python/mllib/streaming_k_means_example.py +++ b/examples/src/main/python/mllib/streaming_k_means_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext from pyspark.streaming import StreamingContext # $example on$ diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py index 714c9a0de7217..1d52e00fbfb5e 100644 --- a/examples/src/main/python/mllib/streaming_linear_regression_example.py +++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py @@ -18,8 +18,6 @@ """ Streaming Linear Regression Example. 
""" -from __future__ import print_function - # $example on$ import sys # $example off$ diff --git a/examples/src/main/python/mllib/summary_statistics_example.py b/examples/src/main/python/mllib/summary_statistics_example.py index d55d1a2c2d0e1..d86e841145501 100644 --- a/examples/src/main/python/mllib/summary_statistics_example.py +++ b/examples/src/main/python/mllib/summary_statistics_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ import numpy as np diff --git a/examples/src/main/python/mllib/tf_idf_example.py b/examples/src/main/python/mllib/tf_idf_example.py index b66412b2334e7..4449066f5b0a6 100644 --- a/examples/src/main/python/mllib/tf_idf_example.py +++ b/examples/src/main/python/mllib/tf_idf_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.feature import HashingTF, IDF diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py index 4e7d4f7610c24..3e5720b4df4d6 100644 --- a/examples/src/main/python/mllib/word2vec.py +++ b/examples/src/main/python/mllib/word2vec.py @@ -23,8 +23,6 @@ # grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines # This was done so that the example can be run in local mode -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/mllib/word2vec_example.py b/examples/src/main/python/mllib/word2vec_example.py index ad1090c77ee11..d37a6e7137b8f 100644 --- a/examples/src/main/python/mllib/word2vec_example.py +++ b/examples/src/main/python/mllib/word2vec_example.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - from pyspark import SparkContext # $example on$ from pyspark.mllib.feature import Word2Vec diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 2c19e8700ab16..0ab7249a82185 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -22,8 +22,6 @@ Example Usage: bin/spark-submit examples/src/main/python/pagerank.py data/mllib/pagerank_data.txt 10 """ -from __future__ import print_function - import re import sys from operator import add diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 83041f0040a0c..ca8dd25e6dabf 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -29,8 +29,6 @@ {u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} <...more log output...> """ -from __future__ import print_function - import sys from pyspark.sql import SparkSession diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index 5839cc2874956..e646722533f68 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - import sys from random import random from operator import add diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index d3cd985d197e3..9efb00a6f1532 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - import sys from pyspark.sql import SparkSession diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index b7d8467172fab..e46449dbefbcd 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -21,21 +21,12 @@ ./bin/spark-submit examples/src/main/python/sql/arrow.py """ -from __future__ import print_function - -import sys - from pyspark.sql import SparkSession from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version require_minimum_pandas_version() require_minimum_pyarrow_version() -if sys.version_info < (3, 6): - raise Exception( - "Running this example file requires Python 3.6+; however, " - "your Python version was:\n %s" % sys.version) - def dataframe_with_arrow_example(spark): # $example on:dataframe_with_arrow$ diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index c8fb25d0533b5..eba8e6ad99d17 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -20,8 +20,6 @@ Run with: ./bin/spark-submit examples/src/main/python/sql/basic.py """ -from __future__ import print_function - # $example on:init_session$ from pyspark.sql import SparkSession # $example off:init_session$ diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index 265f135e1e5f2..94a41a7e5e7b4 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -20,8 +20,6 @@ Run with: ./bin/spark-submit examples/src/main/python/sql/datasource.py """ -from __future__ import print_function - from pyspark.sql import SparkSession # $example on:schema_merging$ from pyspark.sql import Row diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index e96a8af71adc3..bc23dcd9bd2b2 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -20,8 +20,6 @@ Run with: ./bin/spark-submit examples/src/main/python/sql/hive.py """ -from __future__ import print_function - # $example on:spark_hive$ from os.path import join, abspath diff --git a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py index 921067891352a..40a955a46c9b9 100644 --- a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py @@ -36,8 +36,6 @@ `$ bin/spark-submit examples/src/main/python/sql/streaming/structured_kafka_wordcount.py \ host1:port1,host2:port2 subscribe topic1,topic2` """ -from __future__ import print_function - import sys from pyspark.sql import SparkSession diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py index 9ac392164735b..c8f43c9dcf2eb 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -27,8 +27,6 @@ `$ bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py localhost 9999` """ -from __future__ import print_function - import sys from pyspark.sql import SparkSession diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py 
b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index c4e3bbf44cd5a..cc39d8afa6be9 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -39,8 +39,6 @@ One recommended <window duration>, <slide duration> pair is 10, 5 """ -from __future__ import print_function - import sys from pyspark.sql import SparkSession diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py index 8cc8cc820cfce..7b408c87260c0 100644 --- a/examples/src/main/python/status_api_demo.py +++ b/examples/src/main/python/status_api_demo.py @@ -15,15 +15,10 @@ # limitations under the License. # -from __future__ import print_function - import time import threading import sys -if sys.version >= '3': - import queue as Queue -else: - import Queue +import queue as Queue from pyspark import SparkConf, SparkContext diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index f9a5c43a8eaa9..fac07727b7b12 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -25,8 +25,6 @@ Then create a text file in `localdir` and the words in the file will get counted. """ -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index f3099d2517cd5..b57f4e9e38b82 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -25,8 +25,6 @@ and then run the example `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` """ -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index 2b5434c0c845a..5b03546fb4d83 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -30,8 +30,6 @@ localhost 9999` """ -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index a39c4d0b5b8cd..6ebe91a2f47fe 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -35,8 +35,6 @@ checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from the checkpoint data.
""" -from __future__ import print_function - import os import sys @@ -45,10 +43,10 @@ # Get or register a Broadcast variable -def getWordBlacklist(sparkContext): - if ('wordBlacklist' not in globals()): - globals()['wordBlacklist'] = sparkContext.broadcast(["a", "b", "c"]) - return globals()['wordBlacklist'] +def getWordExcludeList(sparkContext): + if ('wordExcludeList' not in globals()): + globals()['wordExcludeList'] = sparkContext.broadcast(["a", "b", "c"]) + return globals()['wordExcludeList'] # Get or register an Accumulator @@ -74,14 +72,14 @@ def createContext(host, port, outputPath): wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) def echo(time, rdd): - # Get or register the blacklist Broadcast - blacklist = getWordBlacklist(rdd.context) + # Get or register the excludeList Broadcast + excludeList = getWordExcludeList(rdd.context) # Get or register the droppedWordsCounter Accumulator droppedWordsCounter = getDroppedWordsCounter(rdd.context) - # Use blacklist to drop words and use droppedWordsCounter to count them + # Use excludeList to drop words and use droppedWordsCounter to count them def filterFunc(wordCount): - if wordCount[0] in blacklist.value: + if wordCount[0] in excludeList.value: droppedWordsCounter.add(wordCount[1]) return False else: diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index ab3cfc067994d..59a8a11a45b19 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -27,8 +27,6 @@ and then run the example `$ bin/spark-submit examples/src/main/python/streaming/sql_network_wordcount.py localhost 9999` """ -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index d5d1eba6c5969..7a45be663a765 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -29,8 +29,6 @@ `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ localhost 9999` """ -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 49551d40851cc..9f543daecd3dd 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - import sys from random import Random diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index a05e24ff3ff95..037c1e8aa379d 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from __future__ import print_function - import sys from operator import add diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index ec9b44ce6e3b7..cf03e0203f771 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -82,7 +82,7 @@ object SparkKMeans { while(tempDist > convergeDist) { val closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) - val pointStats = closest.reduceByKey{case ((p1, c1), (p2, c2)) => (p1 + p2, c1 + c2)} + val pointStats = closest.reduceByKey(mergeResults) val newPoints = pointStats.map {pair => (pair._1, pair._2._1 * (1.0 / pair._2._2))}.collectAsMap() @@ -102,5 +102,11 @@ object SparkKMeans { kPoints.foreach(println) spark.stop() } + + private def mergeResults( + a: (Vector[Double], Int), + b: (Vector[Double], Int)): (Vector[Double], Int) = { + (a._1 + b._1, a._2 + b._2) + } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 243c22e71275c..ee3bbe40fbeed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.{IntParam, LongAccumulator} /** * Use this singleton to get or register a Broadcast variable. */ -object WordBlacklist { +object WordExcludeList { @volatile private var instance: Broadcast[Seq[String]] = null @@ -40,8 +40,8 @@ object WordBlacklist { if (instance == null) { synchronized { if (instance == null) { - val wordBlacklist = Seq("a", "b", "c") - instance = sc.broadcast(wordBlacklist) + val wordExcludeList = Seq("a", "b", "c") + instance = sc.broadcast(wordExcludeList) } } } @@ -60,7 +60,7 @@ object DroppedWordsCounter { if (instance == null) { synchronized { if (instance == null) { - instance = sc.longAccumulator("WordsInBlacklistCounter") + instance = sc.longAccumulator("DroppedWordsCounter") } } } @@ -117,13 +117,13 @@ object RecoverableNetworkWordCount { val words = lines.flatMap(_.split(" ")) val wordCounts = words.map((_, 1)).reduceByKey(_ + _) wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => - // Get or register the blacklist Broadcast - val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the excludeList Broadcast + val excludeList = WordExcludeList.getInstance(rdd.sparkContext) // Get or register the droppedWordsCounter Accumulator val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) - // Use blacklist to drop words and use droppedWordsCounter to count them + // Use excludeList to drop words and use droppedWordsCounter to count them val counts = rdd.filter { case (word, count) => - if (blacklist.value.contains(word)) { + if (excludeList.value.contains(word)) { droppedWordsCounter.add(count) false } else { diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 3947d327dfac6..75690bb7722e3 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -85,7 +85,7 
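The SparkKMeans change above lifts the reduceByKey closure into a named mergeResults helper so the merge of (vector sum, point count) pairs is explicit. A standalone sketch of the same shape is below; it uses plain Scala Vectors with an element-wise zip instead of the example's Breeze vectors, so the object name and the vector handling here are illustrative rather than taken from the patch.

import org.apache.spark.rdd.RDD

object KMeansMerge {
  // Merge two (vectorSum, pointCount) partial results for the same cluster.
  def mergeResults(
      a: (Vector[Double], Int),
      b: (Vector[Double], Int)): (Vector[Double], Int) = {
    (a._1.zip(b._1).map { case (x, y) => x + y }, a._2 + b._2)
  }

  // closest holds (clusterId, (point, 1)) pairs, mirroring the example above.
  def sumByCluster(closest: RDD[(Int, (Vector[Double], Int))]): Map[Int, (Vector[Double], Int)] =
    closest.reduceByKey(mergeResults).collectAsMap().toMap
}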
@@ object SchemaConverters { StructField(f.name, schemaType.dataType, schemaType.nullable) } - SchemaType(StructType(fields), nullable = false) + SchemaType(StructType(fields.toSeq), nullable = false) case ARRAY => val schemaType = toSqlTypeHelper(avroSchema.getElementType, existingRecordNames) @@ -126,7 +126,7 @@ object SchemaConverters { StructField(s"member$i", schemaType.dataType, nullable = true) } - SchemaType(StructType(fields), nullable = false) + SchemaType(StructType(fields.toSeq), nullable = false) } case other => throw new IncompatibleSchemaException(s"Unsupported type $other") diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index e2ae489446d85..83a7ef0061fb2 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -39,9 +39,10 @@ import org.apache.spark.sql.TestingUDT.IntervalData import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, LA, UTC} -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FormattedMode, SparkPlan} import org.apache.spark.sql.execution.datasources.{DataSource, FilePartition} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ import org.apache.spark.sql.test.SharedSparkSession @@ -1808,7 +1809,7 @@ class AvroV1Suite extends AvroSuite { .set(SQLConf.USE_V1_SOURCE_LIST, "avro") } -class AvroV2Suite extends AvroSuite { +class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { import testImplicits._ override protected def sparkConf: SparkConf = @@ -1907,4 +1908,32 @@ class AvroV2Suite extends AvroSuite { assert(scan1.sameResult(scan2)) } } + + test("explain formatted on an avro data source v2") { + withTempDir { dir => + val basePath = dir.getCanonicalPath + "/avro" + val expected_plan_fragment = + s""" + |\\(1\\) BatchScan + |Output \\[2\\]: \\[value#xL, id#x\\] + |DataFilters: \\[isnotnull\\(value#xL\\), \\(value#xL > 2\\)\\] + |Format: avro + |Location: InMemoryFileIndex\\[.*\\] + |PartitionFilters: \\[isnotnull\\(id#x\\), \\(id#x > 1\\)\\] + |ReadSchema: struct\\ + |""".stripMargin.trim + spark.range(10) + .select(col("id"), col("id").as("value")) + .write.option("header", true) + .partitionBy("id") + .format("avro") + .save(basePath) + val df = spark + .read + .format("avro") + .load(basePath).where($"id" > 1 && $"value" > 2) + val normalizedOutput = getNormalizedExplain(df, FormattedMode) + assert(expected_plan_fragment.r.findAllMatchIn(normalizedOutput).length == 1) + } + } } diff --git a/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh b/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh index 00885a3b62327..343bc01651318 100755 --- a/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh +++ b/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh @@ -18,7 +18,7 @@ dpkg-divert --add /bin/systemctl && ln -sT /bin/true /bin/systemctl apt update -apt install -y mariadb-plugin-gssapi-server +apt install -y mariadb-plugin-gssapi-server=1:10.4.12+maria~bionic echo 
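The SchemaConverters hunks above add .toSeq before building StructType because the field collection comes from a Java list via asScala and is therefore a mutable Buffer, while scala.Seq means immutable.Seq on Scala 2.13. A minimal sketch of that conversion; the helper name and the StringType fields are assumptions made purely for illustration.

import scala.collection.JavaConverters._

import org.apache.spark.sql.types.{StringType, StructField, StructType}

object StructFromJavaFields {
  // names.asScala is a mutable.Buffer; StructType(fields: Seq[StructField]) expects scala.Seq,
  // which is immutable.Seq on Scala 2.13, so convert explicitly. On 2.12, .toSeq is effectively a no-op.
  def toStruct(names: java.util.List[String]): StructType = {
    val fields = names.asScala.map(name => StructField(name, StringType, nullable = true))
    StructType(fields.toSeq)
  }
}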
"gssapi_keytab_path=/docker-entrypoint-initdb.d/mariadb.keytab" >> /etc/mysql/mariadb.conf.d/auth_gssapi.cnf echo "gssapi_principal_name=mariadb/__IP_ADDRESS_REPLACE_ME__@EXAMPLE.COM" >> /etc/mysql/mariadb.conf.d/auth_gssapi.cnf docker-entrypoint.sh mysqld diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala index fda377e032350..5abca8df77dcd 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala @@ -36,7 +36,7 @@ private[spark] object DockerUtils { .orElse(findFromDockerMachine()) .orElse(Try(Seq("/bin/bash", "-c", "boot2docker ip 2>/dev/null").!!.trim).toOption) .getOrElse { - // This block of code is based on Utils.findLocalInetAddress(), but is modified to blacklist + // This block of code is based on Utils.findLocalInetAddress(), but is modified to exclude // certain interfaces. val address = InetAddress.getLocalHost // Address resolves to something like 127.0.1.1, which happens on Debian; try to find @@ -44,12 +44,12 @@ private[spark] object DockerUtils { // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order // on unix-like system. On windows, it returns in index order. // It's more proper to pick ip address following system output order. - val blackListedIFs = Seq( + val excludedIFs = Seq( "vboxnet0", // Mac "docker0" // Linux ) val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq.filter { i => - !blackListedIFs.contains(i.getName) + !excludedIFs.contains(i.getName) } val reOrderedNetworkIFs = activeNetworkIFs.reverse for (ni <- reOrderedNetworkIFs) { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 216e74a85c2ae..5ab7862674956 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -336,7 +336,7 @@ private[kafka010] class KafkaOffsetReader( } }) } - incorrectOffsets + incorrectOffsets.toSeq } // Retry to fetch latest offsets when detecting incorrect offsets. 
We don't use diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index bdad214a91343..ee31652eaf1f4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1540,8 +1540,8 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { makeSureGetOffsetCalled, Execute { q => // wait to reach the last offset in every partition - q.awaitOffset( - 0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)), streamingTimeout.toMillis) + q.awaitOffset(0, + KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L).toMap), streamingTimeout.toMillis) }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index 3e32b592b3a3a..ab6550ddf2fb3 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -241,7 +241,7 @@ object ConsumerStrategies { new Subscribe[K, V]( new ju.ArrayList(topics.asJavaCollection), new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).toMap.asJava)) } /** @@ -320,7 +320,7 @@ object ConsumerStrategies { new SubscribePattern[K, V]( pattern, new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).toMap.asJava)) } /** @@ -404,7 +404,7 @@ object ConsumerStrategies { new Assign[K, V]( new ju.ArrayList(topicPartitions.asJavaCollection), new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).toMap.asJava)) } /** diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index a449a8bb7213e..fcdc92580ba35 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -70,7 +70,8 @@ private[spark] class DirectKafkaInputDStream[K, V]( @transient private var kc: Consumer[K, V] = null def consumer(): Consumer[K, V] = this.synchronized { if (null == kc) { - kc = consumerStrategy.onStart(currentOffsets.mapValues(l => java.lang.Long.valueOf(l)).asJava) + kc = consumerStrategy.onStart( + currentOffsets.mapValues(l => java.lang.Long.valueOf(l)).toMap.asJava) } kc } diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java index dc364aca9bd3b..3d6e5ebe978e8 100644 --- 
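The Kafka hunks above append .toMap after mapValues before handing maps to Java. On Scala 2.13, mapValues returns a lazy MapView rather than a Map, so the extra .toMap materializes a strict Map that asJava and the java.util.HashMap copy constructor accept; on 2.12 it is just a harmless copy. A minimal sketch, with String keys standing in for TopicPartition so it does not depend on the Kafka client classes:

import java.{lang => jl, util => ju}

import scala.collection.JavaConverters._

object OffsetConversions {
  // Box the offset values and expose them to a Java API as ju.Map[String, jl.Long].
  // Without .toMap, the MapView produced by mapValues on 2.13 offers no .asJava conversion.
  def toJavaOffsets(offsets: Map[String, Long]): ju.Map[String, jl.Long] =
    new ju.HashMap[String, jl.Long](offsets.mapValues(jl.Long.valueOf).toMap.asJava)
}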
a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java +++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java @@ -48,15 +48,12 @@ public void testConsumerStrategyConstructors() { JavaConverters.mapAsScalaMapConverter(kafkaParams).asScala(); final Map offsets = new HashMap<>(); offsets.put(tp1, 23L); + final Map dummyOffsets = new HashMap<>(); + for (Map.Entry kv : offsets.entrySet()) { + dummyOffsets.put(kv.getKey(), kv.getValue()); + } final scala.collection.Map sOffsets = - JavaConverters.mapAsScalaMapConverter(offsets).asScala().mapValues( - new scala.runtime.AbstractFunction1() { - @Override - public Object apply(Long x) { - return (Object) x; - } - } - ); + JavaConverters.mapAsScalaMap(dummyOffsets); final ConsumerStrategy sub1 = ConsumerStrategies.Subscribe(sTopics, sKafkaParams, sOffsets); diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index 31ca2fe5c95ff..d704aeb507518 100644 --- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -57,7 +57,7 @@ * Example: * # export AWS keys if necessary * $ export AWS_ACCESS_KEY_ID=[your-access-key] - * $ export AWS_SECRET_KEY= + * $ export AWS_SECRET_ACCESS_KEY= * * # run the example * $ SPARK_HOME/bin/run-example streaming.JavaKinesisWordCountASL myAppName mySparkStream \ @@ -68,7 +68,7 @@ * * This code uses the DefaultAWSCredentialsProviderChain to find credentials * in the following order: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY * Java System Properties - aws.accessKeyId and aws.secretKey * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs * Instance profile credentials - delivered through the Amazon EC2 metadata service diff --git a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py index 777a33270c415..df8c64e531cfa 100644 --- a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py +++ b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -32,7 +32,7 @@ Example: # export AWS keys if necessary $ export AWS_ACCESS_KEY_ID= - $ export AWS_SECRET_KEY= + $ export AWS_SECRET_ACCESS_KEY= # run the example $ bin/spark-submit --jars \ @@ -45,7 +45,7 @@ This code uses the DefaultAWSCredentialsProviderChain to find credentials in the following order: - Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY Java System Properties - aws.accessKeyId and aws.secretKey Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs Instance profile credentials - delivered through the Amazon EC2 metadata service @@ -55,8 +55,6 @@ See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on the Kinesis Spark Streaming integration. 
""" -from __future__ import print_function - import sys from pyspark import SparkContext diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index 32f4a6759474f..bbb6008c2dddf 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -51,7 +51,7 @@ import org.apache.spark.streaming.kinesis.KinesisInputDStream * Example: * # export AWS keys if necessary * $ export AWS_ACCESS_KEY_ID= - * $ export AWS_SECRET_KEY= + * $ export AWS_SECRET_ACCESS_KEY= * * # run the example * $ SPARK_HOME/bin/run-example streaming.KinesisWordCountASL myAppName mySparkStream \ @@ -62,7 +62,7 @@ import org.apache.spark.streaming.kinesis.KinesisInputDStream * * This code uses the DefaultAWSCredentialsProviderChain to find credentials * in the following order: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY * Java System Properties - aws.accessKeyId and aws.secretKey * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs * Instance profile credentials - delivered through the Amazon EC2 metadata service diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 8815eb29bc860..3a02e2be6fe04 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -76,7 +76,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * @return fitted models, matching the input parameter maps */ @Since("2.0.0") - def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = { + def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala index 6ef42500f86f7..cc691d1c0c58c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala @@ -30,7 +30,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.{Vector => OldVector} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql._ import org.apache.spark.storage.StorageLevel /** @@ -212,14 +212,34 @@ class FMClassifier @Since("3.0.0") ( if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK) - val coefficients = trainImpl(data, numFeatures, LogisticLoss) + val (coefficients, objectiveHistory) = trainImpl(data, numFeatures, LogisticLoss) val (intercept, linear, factors) = splitCoefficients( coefficients, numFeatures, $(factorSize), $(fitIntercept), $(fitLinear)) if (handlePersistence) data.unpersist() - copyValues(new FMClassificationModel(uid, intercept, linear, factors)) + createModel(dataset, intercept, linear, factors, objectiveHistory) + } + + private def createModel( + dataset: Dataset[_], + intercept: Double, + linear: Vector, + factors: Matrix, + objectiveHistory: Array[Double]): FMClassificationModel = { + val model = 
copyValues(new FMClassificationModel(uid, intercept, linear, factors)) + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + + val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() + val summary = new FMClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + weightColName, + objectiveHistory) + model.setSummary(Some(summary)) } @Since("3.0.0") @@ -243,7 +263,8 @@ class FMClassificationModel private[classification] ( @Since("3.0.0") val linear: Vector, @Since("3.0.0") val factors: Matrix) extends ProbabilisticClassificationModel[Vector, FMClassificationModel] - with FMClassifierParams with MLWritable { + with FMClassifierParams with MLWritable + with HasTrainingSummary[FMClassificationTrainingSummary]{ @Since("3.0.0") override val numClasses: Int = 2 @@ -251,6 +272,27 @@ class FMClassificationModel private[classification] ( @Since("3.0.0") override val numFeatures: Int = linear.size + /** + * Gets summary of model on training set. An exception is thrown + * if `hasSummary` is false. + */ + @Since("3.1.0") + override def summary: FMClassificationTrainingSummary = super.summary + + /** + * Evaluates the model on a test dataset. + * + * @param dataset Test dataset to evaluate model on. + */ + @Since("3.1.0") + def evaluate(dataset: Dataset[_]): FMClassificationSummary = { + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + // Handle possible missing or invalid probability or prediction columns + val (summaryModel, probability, predictionColName) = findSummaryModel() + new FMClassificationSummaryImpl(summaryModel.transform(dataset), + probability, predictionColName, $(labelCol), weightColName) + } + @Since("3.0.0") override def predictRaw(features: Vector): Vector = { val rawPrediction = getRawPrediction(features, intercept, linear, factors) @@ -328,3 +370,53 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] { } } } + +/** + * Abstraction for FMClassifier results for a given model. + */ +sealed trait FMClassificationSummary extends BinaryClassificationSummary + +/** + * Abstraction for FMClassifier training results. + */ +sealed trait FMClassificationTrainingSummary extends FMClassificationSummary with TrainingSummary + +/** + * FMClassifier results for a given model. + * + * @param predictions dataframe output by the model's `transform` method. + * @param scoreCol field in "predictions" which gives the probability of each instance. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param weightCol field in "predictions" which gives the weight of each instance. + */ +private class FMClassificationSummaryImpl( + @transient override val predictions: DataFrame, + override val scoreCol: String, + override val predictionCol: String, + override val labelCol: String, + override val weightCol: String) + extends FMClassificationSummary + +/** + * FMClassifier training results. + * + * @param predictions dataframe output by the model's `transform` method. + * @param scoreCol field in "predictions" which gives the probability of each instance. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. + * @param labelCol field in "predictions" which gives the true label of each instance. 
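The summary wiring added above gives `FMClassificationModel` the same post-training introspection that `LogisticRegressionModel` already offers: a training summary populated at fit time plus an `evaluate` method for a separate dataset. A hedged usage sketch; the data path is a placeholder and the same dataset is reused for evaluation only to keep the example short:

```scala
import org.apache.spark.ml.classification.FMClassifier
import org.apache.spark.sql.SparkSession

object FMSummarySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")

    val model = new FMClassifier().setMaxIter(10).fit(df)

    // Training summary captured by createModel() above.
    println(s"iterations = ${model.summary.totalIterations}")
    println(s"accuracy   = ${model.summary.accuracy}")

    // Evaluate on a dataset via the new evaluate() method.
    val eval = model.evaluate(df)
    println(s"areaUnderROC = ${eval.areaUnderROC}")
    spark.stop()
  }
}
```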
+ * @param weightCol field in "predictions" which gives the weight of each instance. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +private class FMClassificationTrainingSummaryImpl( + predictions: DataFrame, + scoreCol: String, + predictionCol: String, + labelCol: String, + weightCol: String, + override val objectiveHistory: Array[Double]) + extends FMClassificationSummaryImpl( + predictions, scoreCol, predictionCol, labelCol, weightCol) + with FMClassificationTrainingSummary diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 18fd220b4ca9c..90845021fc073 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -492,12 +492,7 @@ class GaussianMixture @Since("2.0.0") ( (i, (agg.means(i), agg.covs(i), agg.weights(i), ws)) } } else Iterator.empty - }.reduceByKey { case ((mean1, cov1, w1, ws1), (mean2, cov2, w2, ws2)) => - // update the weights, means and covariances for i-th distributions - BLAS.axpy(1.0, mean2, mean1) - BLAS.axpy(1.0, cov2, cov1) - (mean1, cov1, w1 + w2, ws1 + ws2) - }.mapValues { case (mean, cov, w, ws) => + }.reduceByKey(GaussianMixture.mergeWeightsMeans).mapValues { case (mean, cov, w, ws) => // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) GaussianMixture.updateWeightsAndGaussians(mean, cov, w, ws) @@ -560,12 +555,7 @@ class GaussianMixture @Since("2.0.0") ( agg.meanIter.zip(agg.covIter).zipWithIndex .map { case ((mean, cov), i) => (i, (mean, cov, agg.weights(i), ws)) } } else Iterator.empty - }.reduceByKey { case ((mean1, cov1, w1, ws1), (mean2, cov2, w2, ws2)) => - // update the weights, means and covariances for i-th distributions - BLAS.axpy(1.0, mean2, mean1) - BLAS.axpy(1.0, cov2, cov1) - (mean1, cov1, w1 + w2, ws1 + ws2) - }.mapValues { case (mean, cov, w, ws) => + }.reduceByKey(GaussianMixture.mergeWeightsMeans).mapValues { case (mean, cov, w, ws) => // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) GaussianMixture.updateWeightsAndGaussians(mean, cov, w, ws) @@ -624,8 +614,8 @@ class GaussianMixture @Since("2.0.0") ( val gaussians = Array.tabulate(numClusters) { i => val start = i * numSamples val end = start + numSamples - val sampleSlice = samples.view(start, end) - val weightSlice = sampleWeights.view(start, end) + val sampleSlice = samples.view.slice(start, end) + val weightSlice = sampleWeights.view.slice(start, end) val localWeightSum = weightSlice.sum weights(i) = localWeightSum / weightSum @@ -691,6 +681,16 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] { new DenseMatrix(n, n, symmetricValues) } + private def mergeWeightsMeans( + a: (DenseVector, DenseVector, Double, Double), + b: (DenseVector, DenseVector, Double, Double)): (DenseVector, DenseVector, Double, Double) = + { + // update the weights, means and covariances for i-th distributions + BLAS.axpy(1.0, b._1, a._1) + BLAS.axpy(1.0, b._2, a._2) + (a._1, a._2, a._3 + b._3, a._4 + b._4) + } + /** * Update the weight, mean and covariance of gaussian distribution. 
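The `samples.view(start, end)` rewrites here, and in the `mllib` impurity aggregators later in this patch, track the deprecation of the two-argument `view(from, until)` overload in Scala 2.13; `view.slice(from, until)` is the equivalent lazy slice on both 2.12 and 2.13. A standalone sketch:

```scala
object ViewSliceSketch {
  def main(args: Array[String]): Unit = {
    val samples = Array.tabulate(10)(_.toDouble)

    // Scala 2.12: samples.view(2, 5); Scala 2.13: samples.view.slice(2, 5).
    // Both forms are lazy, so no intermediate array is materialized until
    // something like .sum or .toArray forces the elements.
    val slice = samples.view.slice(2, 5)

    println(slice.sum)                       // 2.0 + 3.0 + 4.0 = 9.0
    println(slice.toArray.mkString(", "))    // "2.0, 3.0, 4.0"
  }
}
```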
* diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala index bd9be779fedbd..72ab3dbc31016 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala @@ -201,7 +201,7 @@ object RobustScaler extends DefaultParamsReadable[RobustScaler] { } Iterator.tabulate(numFeatures)(i => (i, summaries(i).compress)) } else Iterator.empty - }.reduceByKey { case (s1, s2) => s1.merge(s2) } + }.reduceByKey { (s1, s2) => s1.merge(s2) } } else { val scale = math.max(math.ceil(math.sqrt(vectors.getNumPartitions)).toInt, 2) vectors.mapPartitionsWithIndex { case (pid, iter) => @@ -214,7 +214,7 @@ object RobustScaler extends DefaultParamsReadable[RobustScaler] { seqOp = (s, v) => s.insert(v), combOp = (s1, s2) => s1.compress.merge(s2.compress) ).map { case ((_, i), s) => (i, s) - }.reduceByKey { case (s1, s2) => s1.compress.merge(s2.compress) } + }.reduceByKey { (s1, s2) => s1.compress.merge(s2.compress) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index bbfcbfbe038ef..db2665fa2e4a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -291,7 +291,7 @@ class Word2VecModel private[ml] ( val outputSchema = transformSchema(dataset.schema, logging = true) val vectors = wordVectors.getVectors .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) - .map(identity) // mapValues doesn't return a serializable map (SI-7005) + .map(identity).toMap // mapValues doesn't return a serializable map (SI-7005) val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors) val d = $(vectorSize) val emptyVec = Vectors.sparse(d, Array.emptyIntArray, Array.emptyDoubleArray) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 53ca35ccd0073..f12c1f995b7d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -937,7 +937,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** Put param pairs with a `java.util.List` of values for Python. */ private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = { - put(paramPairs.asScala: _*) + put(paramPairs.asScala.toSeq: _*) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala index df4dac1e240e2..84c0985245a2e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala @@ -47,7 +47,7 @@ import org.apache.spark.storage.StorageLevel */ private[ml] trait FactorizationMachinesParams extends PredictorParams with HasMaxIter with HasStepSize with HasTol with HasSolver with HasSeed - with HasFitIntercept with HasRegParam { + with HasFitIntercept with HasRegParam with HasWeightCol { /** * Param for dimensionality of the factors (>= 0) @@ -112,6 +112,10 @@ private[ml] trait FactorizationMachinesParams extends PredictorParams "The solver algorithm for optimization. Supported options: " + s"${supportedSolvers.mkString(", ")}. 
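The `reduceByKey { case (s1, s2) => ... }` to `reduceByKey { (s1, s2) => ... }` edits above swap a pattern-matching anonymous function for a plain two-parameter lambda. Both produce a `(V, V) => V`; the `case` form is expanded by the compiler into a synthetic match over the tuple of arguments, while the plain lambda states the reduction directly. A minimal sketch using an ordinary collection in place of an RDD, an assumption made only for brevity:

```scala
object ReducerLambdaSketch {
  final case class Summary(count: Long) {
    def merge(other: Summary): Summary = Summary(count + other.count)
  }

  // Pattern-matching form: the compiler expands it to
  // (x, y) => (x, y) match { case (s1, s2) => s1.merge(s2) }
  val viaCase: (Summary, Summary) => Summary = { case (s1, s2) => s1.merge(s2) }

  // Plain two-parameter lambda: the same reduction, stated directly.
  val viaLambda: (Summary, Summary) => Summary = (s1, s2) => s1.merge(s2)

  def main(args: Array[String]): Unit = {
    val parts = Seq(Summary(2), Summary(3), Summary(5))
    println(parts.reduce(viaCase))   // Summary(10)
    println(parts.reduce(viaLambda)) // Summary(10)
  }
}
```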
(Default adamW)", ParamValidators.inArray[String](supportedSolvers)) + + setDefault(factorSize -> 8, fitIntercept -> true, fitLinear -> true, regParam -> 0.0, + miniBatchFraction -> 1.0, initStd -> 0.01, maxIter -> 100, stepSize -> 1.0, tol -> 1E-6, + solver -> AdamW) } private[ml] trait FactorizationMachines extends FactorizationMachinesParams { @@ -130,7 +134,7 @@ private[ml] trait FactorizationMachines extends FactorizationMachinesParams { data: RDD[(Double, OldVector)], numFeatures: Int, loss: String - ): Vector = { + ): (Vector, Array[Double]) = { // initialize coefficients val initialCoefficients = initCoefficients(numFeatures) @@ -147,8 +151,8 @@ private[ml] trait FactorizationMachines extends FactorizationMachinesParams { .setRegParam($(regParam)) .setMiniBatchFraction($(miniBatchFraction)) .setConvergenceTol($(tol)) - val coefficients = optimizer.optimize(data, initialCoefficients) - coefficients.asML + val (coefficients, lossHistory) = optimizer.optimizeWithLossReturned(data, initialCoefficients) + (coefficients.asML, lossHistory) } } @@ -308,7 +312,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setFactorSize(value: Int): this.type = set(factorSize, value) - setDefault(factorSize -> 8) /** * Set whether to fit intercept term. @@ -318,7 +321,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) - setDefault(fitIntercept -> true) /** * Set whether to fit linear term. @@ -328,7 +330,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setFitLinear(value: Boolean): this.type = set(fitLinear, value) - setDefault(fitLinear -> true) /** * Set the L2 regularization parameter. @@ -338,7 +339,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setRegParam(value: Double): this.type = set(regParam, value) - setDefault(regParam -> 0.0) /** * Set the mini-batch fraction parameter. @@ -348,7 +348,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setMiniBatchFraction(value: Double): this.type = set(miniBatchFraction, value) - setDefault(miniBatchFraction -> 1.0) /** * Set the standard deviation of initial coefficients. @@ -358,7 +357,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setInitStd(value: Double): this.type = set(initStd, value) - setDefault(initStd -> 0.01) /** * Set the maximum number of iterations. @@ -368,7 +366,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setMaxIter(value: Int): this.type = set(maxIter, value) - setDefault(maxIter -> 100) /** * Set the initial step size for the first step (like learning rate). @@ -378,7 +375,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(stepSize -> 1.0) /** * Set the convergence tolerance of iterations. @@ -388,7 +384,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setTol(value: Double): this.type = set(tol, value) - setDefault(tol -> 1E-6) /** * Set the solver algorithm used for optimization. @@ -399,7 +394,6 @@ class FMRegressor @Since("3.0.0") ( */ @Since("3.0.0") def setSolver(value: String): this.type = set(solver, value) - setDefault(solver -> AdamW) /** * Set the random seed for weight initialization. 
@@ -427,7 +421,7 @@ class FMRegressor @Since("3.0.0") ( if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK) - val coefficients = trainImpl(data, numFeatures, SquaredError) + val (coefficients, _) = trainImpl(data, numFeatures, SquaredError) val (intercept, linear, factors) = splitCoefficients( coefficients, numFeatures, $(factorSize), $(fitIntercept), $(fitLinear)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 0ee895a95a288..f7dfda81d4e6f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -181,6 +181,9 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam s"${supportedSolvers.mkString(", ")}. (Default irls)", ParamValidators.inArray[String](supportedSolvers)) + setDefault(family -> Gaussian.name, variancePower -> 0.0, maxIter -> 25, tol -> 1E-6, + regParam -> 0.0, solver -> IRLS) + @Since("2.0.0") override def validateAndTransformSchema( schema: StructType, @@ -257,7 +260,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.0.0") def setFamily(value: String): this.type = set(family, value) - setDefault(family -> Gaussian.name) /** * Sets the value of param [[variancePower]]. @@ -268,7 +270,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.2.0") def setVariancePower(value: Double): this.type = set(variancePower, value) - setDefault(variancePower -> 0.0) /** * Sets the value of param [[linkPower]]. @@ -305,7 +306,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.0.0") def setMaxIter(value: Int): this.type = set(maxIter, value) - setDefault(maxIter -> 25) /** * Sets the convergence tolerance of iterations. @@ -316,7 +316,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.0.0") def setTol(value: Double): this.type = set(tol, value) - setDefault(tol -> 1E-6) /** * Sets the regularization parameter for L2 regularization. @@ -332,7 +331,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.0.0") def setRegParam(value: Double): this.type = set(regParam, value) - setDefault(regParam -> 0.0) /** * Sets the value of param [[weightCol]]. @@ -364,7 +362,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.0.0") def setSolver(value: String): this.type = set(solver, value) - setDefault(solver -> IRLS) /** * Sets the link prediction (linear predictor) column name. @@ -1220,10 +1217,41 @@ class GeneralizedLinearRegressionSummary private[regression] ( private[regression] lazy val link: Link = familyLink.link + /** + * summary row containing: + * numInstances, weightSum, deviance, rss, weighted average of label - offset. 
+ */ + private lazy val glrSummary = { + val devUDF = udf { (label: Double, pred: Double, weight: Double) => + family.deviance(label, pred, weight) + } + val devCol = sum(devUDF(label, prediction, weight)) + + val rssCol = if (model.getFamily.toLowerCase(Locale.ROOT) != Binomial.name && + model.getFamily.toLowerCase(Locale.ROOT) != Poisson.name) { + val rssUDF = udf { (label: Double, pred: Double, weight: Double) => + (label - pred) * (label - pred) * weight / family.variance(pred) + } + sum(rssUDF(label, prediction, weight)) + } else { + lit(Double.NaN) + } + + val avgCol = if (model.getFitIntercept && + (!model.hasOffsetCol || (model.hasOffsetCol && family == Gaussian && link == Identity))) { + sum((label - offset) * weight) / sum(weight) + } else { + lit(Double.NaN) + } + + predictions + .select(count(label), sum(weight), devCol, rssCol, avgCol) + .head() + } + /** Number of instances in DataFrame predictions. */ @Since("2.2.0") - lazy val numInstances: Long = predictions.count() - + lazy val numInstances: Long = glrSummary.getLong(0) /** * Name of features. If the name cannot be retrieved from attributes, @@ -1335,9 +1363,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ if (!model.hasOffsetCol || (model.hasOffsetCol && family == Gaussian && link == Identity)) { - val agg = predictions.agg(sum(weight.multiply( - label.minus(offset))), sum(weight)).first() - link.link(agg.getDouble(0) / agg.getDouble(1)) + link.link(glrSummary.getDouble(4)) } else { // Create empty feature column and fit intercept only model using param setting from model val featureNull = "feature_" + java.util.UUID.randomUUID.toString @@ -1362,12 +1388,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( * The deviance for the fitted model. */ @Since("2.0.0") - lazy val deviance: Double = { - predictions.select(label, prediction, weight).rdd.map { - case Row(label: Double, pred: Double, weight: Double) => - family.deviance(label, pred, weight) - }.sum() - } + lazy val deviance: Double = glrSummary.getDouble(2) /** * The dispersion of the fitted model. @@ -1381,14 +1402,14 @@ class GeneralizedLinearRegressionSummary private[regression] ( model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { 1.0 } else { - val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0) + val rss = glrSummary.getDouble(3) rss / degreesOfFreedom } /** Akaike Information Criterion (AIC) for the fitted model. 
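The new `glrSummary` row above folds what used to be several separate jobs (`predictions.count()` for `numInstances`, an RDD pass for `deviance`, a Pearson-residual aggregation for the dispersion `rss`, and a weight sum for `aic`) into a single `select` of aggregate columns, so the summary triggers one scan over `predictions`. The same single-pass idea in a hedged, generic form; the column names and statistics here are placeholders, not the GLR formulas:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, sum, udf}

object SinglePassAggSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    import spark.implicits._

    val predictions = Seq(
      (1.0, 0.9, 1.0),
      (0.0, 0.2, 2.0),
      (1.0, 0.7, 1.5)
    ).toDF("label", "prediction", "weight")

    // One UDF per derived statistic, all folded into a single aggregation,
    // so the DataFrame is scanned once instead of once per metric.
    val sqErr = udf((label: Double, pred: Double, w: Double) =>
      (label - pred) * (label - pred) * w)

    val row = predictions
      .select(
        count($"label"),
        sum($"weight"),
        sum(sqErr($"label", $"prediction", $"weight")))
      .head()

    val (n, weightSum, rss) = (row.getLong(0), row.getDouble(1), row.getDouble(2))
    println(s"n=$n, weightSum=$weightSum, rss=$rss")
    spark.stop()
  }
}
```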
*/ @Since("2.0.0") lazy val aic: Double = { - val weightSum = predictions.select(weight).agg(sum(weight)).first().getDouble(0) + val weightSum = glrSummary.getDouble(1) val t = predictions.select( label, prediction, weight).rdd.map { case Row(label: Double, pred: Double, weight: Double) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index d9f09c097292a..de559142a9261 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -1037,7 +1037,7 @@ class LinearRegressionSummary private[regression] ( } /** Number of instances in DataFrame predictions */ - lazy val numInstances: Long = predictions.count() + lazy val numInstances: Long = metrics.count /** Degrees of freedom */ @Since("2.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 259ecb3a1762f..68f6ed4281dea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1223,28 +1223,28 @@ private[python] class PythonMLLibAPI extends Serializable { * Python-friendly version of [[MLUtils.convertVectorColumnsToML()]]. */ def convertVectorColumnsToML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { - MLUtils.convertVectorColumnsToML(dataset, cols.asScala: _*) + MLUtils.convertVectorColumnsToML(dataset, cols.asScala.toSeq: _*) } /** * Python-friendly version of [[MLUtils.convertVectorColumnsFromML()]] */ def convertVectorColumnsFromML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { - MLUtils.convertVectorColumnsFromML(dataset, cols.asScala: _*) + MLUtils.convertVectorColumnsFromML(dataset, cols.asScala.toSeq: _*) } /** * Python-friendly version of [[MLUtils.convertMatrixColumnsToML()]]. 
*/ def convertMatrixColumnsToML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { - MLUtils.convertMatrixColumnsToML(dataset, cols.asScala: _*) + MLUtils.convertMatrixColumnsToML(dataset, cols.asScala.toSeq: _*) } /** * Python-friendly version of [[MLUtils.convertMatrixColumnsFromML()]] */ def convertMatrixColumnsFromML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { - MLUtils.convertMatrixColumnsFromML(dataset, cols.asScala: _*) + MLUtils.convertMatrixColumnsFromML(dataset, cols.asScala.toSeq: _*) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 7c12697be95c8..99c6e8b3e079b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -225,7 +225,7 @@ class BisectingKMeans private ( divisibleIndices.contains(parentIndex(index)) } newClusters = summarize(d, newAssignments, dMeasure) - newClusterCenters = newClusters.mapValues(_.center).map(identity) + newClusterCenters = newClusters.mapValues(_.center).map(identity).toMap } if (preIndices != null) { preIndices.unpersist() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 4d98ba41bbb7b..d5a7882614546 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.clustering -import scala.collection.mutable.IndexedSeq - import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} import org.apache.spark.annotation.Since @@ -189,8 +187,8 @@ class GaussianMixture private ( case None => val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => - val slice = samples.view(i * nSamples, (i + 1) * nSamples) - new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + val slice = samples.view.slice(i * nSamples, (i + 1) * nSamples) + new MultivariateGaussian(vectorMean(slice.toSeq), initCovariance(slice.toSeq)) }) } @@ -259,7 +257,7 @@ class GaussianMixture private ( } /** Average of dense breeze vectors */ - private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { + private def vectorMean(x: Seq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) v / x.length.toDouble @@ -269,7 +267,7 @@ class GaussianMixture private ( * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) */ - private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = { + private def initCovariance(x: Seq[BV[Double]]): BreezeMatrix[Double] = { val mu = vectorMean(x) val ss = BDV.zeros[Double](x(0).length) x.foreach { xi => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index b697d2746ce7b..7938427544bd9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -131,4 +131,6 @@ class RegressionMetrics @Since("2.0.0") ( 1 - SSerr / SStot } } + + private[spark] 
def count: Long = summary.count } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index ac2b576f4ac4e..de3209c34bf07 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -335,7 +335,7 @@ object PrefixSpan extends Logging { largePrefixes = newLargePrefixes } - var freqPatterns = sc.parallelize(localFreqPatterns, 1) + var freqPatterns = sc.parallelize(localFreqPatterns.toSeq, 1) val numSmallPrefixes = smallPrefixes.size logInfo(s"number of small prefixes for local processing: $numSmallPrefixes") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 1336ffd2f7d5e..796a787e77db4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -129,7 +129,20 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * @return solution vector */ def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { - val (weights, _) = GradientDescent.runMiniBatchSGD( + val (weights, _) = optimizeWithLossReturned(data, initialWeights) + weights + } + + /** + * Runs gradient descent on the given training data. + * @param data training data + * @param initialWeights initial weights + * @return solution vector and loss value in an array + */ + def optimizeWithLossReturned( + data: RDD[(Double, Vector)], + initialWeights: Vector): (Vector, Array[Double]) = { + GradientDescent.runMiniBatchSGD( data, gradient, updater, @@ -139,7 +152,6 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va miniBatchFraction, initialWeights, convergenceTol) - weights } } @@ -195,7 +207,7 @@ object GradientDescent extends Logging { s"numIterations=$numIterations and miniBatchFraction=$miniBatchFraction") } - val stochasticLossHistory = new ArrayBuffer[Double](numIterations) + val stochasticLossHistory = new ArrayBuffer[Double](numIterations + 1) // Record previous weight and current one to calculate solution vector difference var previousWeights: Option[Vector] = None @@ -226,7 +238,7 @@ object GradientDescent extends Logging { var converged = false // indicates whether converged based on convergenceTol var i = 1 - while (!converged && i <= numIterations) { + while (!converged && (i <= numIterations + 1)) { val bcWeights = data.context.broadcast(weights) // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) @@ -249,17 +261,19 @@ object GradientDescent extends Logging { * and regVal is the regularization value computed in the previous iteration as well. 
*/ stochasticLossHistory += lossSum / miniBatchSize + regVal - val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), - stepSize, i, regParam) - weights = update._1 - regVal = update._2 - - previousWeights = currentWeights - currentWeights = Some(weights) - if (previousWeights != None && currentWeights != None) { - converged = isConverged(previousWeights.get, - currentWeights.get, convergenceTol) + if (i != (numIterations + 1)) { + val update = updater.compute( + weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), + stepSize, i, regParam) + weights = update._1 + regVal = update._2 + + previousWeights = currentWeights + currentWeights = Some(weights) + if (previousWeights != None && currentWeights != None) { + converged = isConverged(previousWeights.get, + currentWeights.get, convergenceTol) + } } } else { logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero") @@ -271,7 +285,6 @@ object GradientDescent extends Logging { stochasticLossHistory.takeRight(10).mkString(", "))) (weights, stochasticLossHistory.toArray) - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 1ee9241104f87..4fc297560c088 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -136,7 +136,14 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) } override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { - val (weights, _) = LBFGS.runLBFGS( + val (weights, _) = optimizeWithLossReturned(data, initialWeights) + weights + } + + def optimizeWithLossReturned( + data: RDD[(Double, Vector)], + initialWeights: Vector): (Vector, Array[Double]) = { + LBFGS.runLBFGS( data, gradient, updater, @@ -145,9 +152,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) maxNumIterations, regParam, initialWeights) - weights } - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index 365b2a06110f6..c669ced61d2f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -97,7 +97,7 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int } if (sizes(i) + tail.length >= offset + windowSize) { partitions += - new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset) + new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail.toSeq, offset) partitionIndex += 1 } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 6e2732f7ae7aa..c3bda99786310 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -112,7 +112,7 @@ private[spark] class EntropyAggregator(numClasses: Int) * @param offset Start index of stats for this (node, feature, bin). 
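The extra loop pass above (`i <= numIterations + 1`, with the weight update skipped on the last pass) makes `stochasticLossHistory` end with the loss of the weights that are actually returned, and the new `optimizeWithLossReturned` methods on `GradientDescent` and `LBFGS` expose that history to callers such as `trainImpl` in `FMRegressor`/`FMClassifier`. A hedged usage sketch against the public `LBFGS` optimizer, using toy data and illustrative settings:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}
import org.apache.spark.sql.SparkSession

object LossHistorySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val sc = spark.sparkContext

    // Tiny binary-classification toy data: (label, features).
    val data = sc.parallelize(Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(1.0, 0.0)),
      (1.0, Vectors.dense(0.9, 0.1)),
      (0.0, Vectors.dense(0.1, 0.9))))

    val optimizer = new LBFGS(new LogisticGradient(), new SquaredL2Updater())
      .setNumIterations(10)
      .setRegParam(0.01)

    // New in this patch: the solution vector plus the recorded loss history,
    // instead of only the weights from optimize().
    val (weights, lossHistory) =
      optimizer.optimizeWithLossReturned(data, Vectors.zeros(2))

    println(s"weights = $weights")
    println(s"final loss = ${lossHistory.last}, entries = ${lossHistory.length}")
    spark.stop()
  }
}
```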
*/ def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = { - new EntropyCalculator(allStats.view(offset, offset + statsSize - 1).toArray, + new EntropyCalculator(allStats.view.slice(offset, offset + statsSize - 1).toArray, allStats(offset + statsSize - 1).toLong) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 5983118c05754..70163b56408a4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -107,7 +107,7 @@ private[spark] class GiniAggregator(numClasses: Int) * @param offset Start index of stats for this (node, feature, bin). */ def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = { - new GiniCalculator(allStats.view(offset, offset + statsSize - 1).toArray, + new GiniCalculator(allStats.view.slice(offset, offset + statsSize - 1).toArray, allStats(offset + statsSize - 1).toLong) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index f5b2f8d514c7e..7143fd07d7333 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -95,7 +95,7 @@ private[spark] class VarianceAggregator() * @param offset Start index of stats for this (node, feature, bin). */ def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = { - new VarianceCalculator(allStats.view(offset, offset + statsSize - 1).toArray, + new VarianceCalculator(allStats.view.slice(offset, offset + statsSize - 1).toArray, allStats(offset + statsSize - 1).toLong) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index 2c613348c2d92..959e54e4c7169 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -85,10 +85,10 @@ private[mllib] object NumericParser { while (parsing && tokenizer.hasMoreTokens()) { token = tokenizer.nextToken() if (token == "(") { - items.append(parseTuple(tokenizer)) + items += parseTuple(tokenizer) allowComma = true } else if (token == "[") { - items.append(parseArray(tokenizer)) + items += parseArray(tokenizer) allowComma = true } else if (token == ",") { if (allowComma) { @@ -102,14 +102,14 @@ private[mllib] object NumericParser { // ignore whitespaces between delim chars, e.g. 
", [" } else { // expecting a number - items.append(parseDouble(token)) + items += parseDouble(token) allowComma = true } } if (parsing) { throw new SparkException(s"A tuple must end with ')'.") } - items + items.toSeq } private def parseDouble(s: String): Double = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala index d477049824b19..9a04bdc39718c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala @@ -194,6 +194,32 @@ class FMClassifierSuite extends MLTest with DefaultReadWriteTest { testPredictionModelSinglePrediction(fmModel, smallBinaryDataset) } + test("summary and training summary") { + val fm = new FMClassifier() + val model = fm.setMaxIter(5).fit(smallBinaryDataset) + + val summary = model.evaluate(smallBinaryDataset) + + assert(model.summary.accuracy === summary.accuracy) + assert(model.summary.weightedPrecision === summary.weightedPrecision) + assert(model.summary.weightedRecall === summary.weightedRecall) + assert(model.summary.pr.collect() === summary.pr.collect()) + assert(model.summary.roc.collect() === summary.roc.collect()) + assert(model.summary.areaUnderROC === summary.areaUnderROC) + } + + test("FMClassifier training summary totalIterations") { + Seq(1, 5, 10, 20, 100).foreach { maxIter => + val trainer = new FMClassifier().setMaxIter(maxIter) + val model = trainer.fit(smallBinaryDataset) + if (maxIter == 1) { + assert(model.summary.totalIterations === maxIter) + } else { + assert(model.summary.totalIterations <= maxIter) + } + } + } + test("read/write") { def checkModelData( model: FMClassificationModel, diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index debd0dd65d0c8..04b20d1e58dd3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -219,7 +219,7 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { model1.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) - val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Array( + val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Seq( (Vectors.dense(1.0, 1.0), 2.0), (Vectors.dense(10.0, 10.0), 2.0), (Vectors.dense(1.0, 0.5), 2.0), (Vectors.dense(10.0, 4.4), 2.0), (Vectors.dense(-1.0, 1.0), 2.0), (Vectors.dense(-100.0, 90.0), 2.0)))) @@ -286,7 +286,7 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { model1.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) - val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Array( + val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Seq( (Vectors.dense(1.0, 1.0), 1.0), (Vectors.dense(10.0, 10.0), 2.0), (Vectors.dense(1.0, 0.5), 2.0), (Vectors.dense(10.0, 4.4), 3.0), (Vectors.dense(-1.0, 1.0), 3.0), (Vectors.dense(-100.0, 90.0), 4.0)))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 584594436267f..61f4359d99ea9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -255,7 +255,7 @@ class 
KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes } test("compare with weightCol and without weightCol") { - val df1 = spark.createDataFrame(spark.sparkContext.parallelize(Array( + val df1 = spark.createDataFrame(spark.sparkContext.parallelize(Seq( Vectors.dense(1.0, 1.0), Vectors.dense(10.0, 10.0), Vectors.dense(10.0, 10.0), Vectors.dense(1.0, 0.5), @@ -285,7 +285,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes model1.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) - val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Array( + val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Seq( (Vectors.dense(1.0, 1.0), 1.0), (Vectors.dense(10.0, 10.0), 2.0), (Vectors.dense(1.0, 0.5), 1.0), @@ -322,7 +322,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes test("Two centers with weightCol") { // use the same weight for all samples. - val df1 = spark.createDataFrame(spark.sparkContext.parallelize(Array( + val df1 = spark.createDataFrame(spark.sparkContext.parallelize(Seq( (Vectors.dense(0.0, 0.0), 2.0), (Vectors.dense(0.0, 0.1), 2.0), (Vectors.dense(0.1, 0.0), 2.0), @@ -366,7 +366,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(model1.clusterCenters(1) === model1_center2) // use different weight - val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Array( + val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Seq( (Vectors.dense(0.0, 0.0), 1.0), (Vectors.dense(0.0, 0.1), 2.0), (Vectors.dense(0.1, 0.0), 3.0), @@ -412,7 +412,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes test("Four centers with weightCol") { // no weight - val df1 = spark.createDataFrame(spark.sparkContext.parallelize(Array( + val df1 = spark.createDataFrame(spark.sparkContext.parallelize(Seq( Vectors.dense(0.1, 0.1), Vectors.dense(5.0, 0.2), Vectors.dense(10.0, 0.0), @@ -444,7 +444,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes model1.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) // use same weight, should have the same result as no weight - val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Array( + val df2 = spark.createDataFrame(spark.sparkContext.parallelize(Seq( (Vectors.dense(0.1, 0.1), 2.0), (Vectors.dense(5.0, 0.2), 2.0), (Vectors.dense(10.0, 0.0), 2.0), diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index d4c620adc2e3c..06f2cb2b9788b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -189,7 +189,7 @@ class ClusteringEvaluatorSuite } test("single-element clusters with weight") { - val singleItemClusters = spark.createDataFrame(spark.sparkContext.parallelize(Array( + val singleItemClusters = spark.createDataFrame(spark.sparkContext.parallelize(Seq( (0.0, Vectors.dense(5.1, 3.5, 1.4, 0.2), 6.0), (1.0, Vectors.dense(7.0, 3.2, 4.7, 1.4), 0.25), (2.0, Vectors.dense(6.3, 3.3, 6.0, 2.5), 9.99)))).toDF("label", "features", "weight") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index d97df0050d74e..1c602cd7d9a4f 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -29,14 +29,14 @@ class NormalizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ - @transient var data: Array[Vector] = _ - @transient var l1Normalized: Array[Vector] = _ - @transient var l2Normalized: Array[Vector] = _ + @transient var data: Seq[Vector] = _ + @transient var l1Normalized: Seq[Vector] = _ + @transient var l2Normalized: Seq[Vector] = _ override def beforeAll(): Unit = { super.beforeAll() - data = Array( + data = Seq( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), Vectors.dense(0.0, 0.0, 0.0), Vectors.dense(0.6, -1.1, -3.0), @@ -44,7 +44,7 @@ class NormalizerSuite extends MLTest with DefaultReadWriteTest { Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), Vectors.sparse(3, Seq()) ) - l1Normalized = Array( + l1Normalized = Seq( Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))), Vectors.dense(0.0, 0.0, 0.0), Vectors.dense(0.12765957, -0.23404255, -0.63829787), @@ -52,7 +52,7 @@ class NormalizerSuite extends MLTest with DefaultReadWriteTest { Vectors.dense(0.625, 0.07894737, 0.29605263), Vectors.sparse(3, Seq()) ) - l2Normalized = Array( + l2Normalized = Seq( Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))), Vectors.dense(0.0, 0.0, 0.0), Vectors.dense(0.184549876, -0.3383414, -0.922749378), diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 9029fc96b36a8..28275eb06cf0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -307,7 +307,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { } logInfo(s"Generated an explicit feedback dataset with ${training.size} ratings for training " + s"and ${test.size} for test.") - (sc.parallelize(training, 2), sc.parallelize(test, 2)) + (sc.parallelize(training.toSeq, 2), sc.parallelize(test.toSeq, 2)) } /** @@ -810,7 +810,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { val topItems = model.recommendForAllUsers(k) assert(topItems.count() == numUsers) assert(topItems.columns.contains("user")) - checkRecommendations(topItems, expectedUpToN, "item") + checkRecommendations(topItems, expectedUpToN.toMap, "item") } } @@ -831,7 +831,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { val topUsers = getALSModel.recommendForAllItems(k) assert(topUsers.count() == numItems) assert(topUsers.columns.contains("item")) - checkRecommendations(topUsers, expectedUpToN, "user") + checkRecommendations(topUsers, expectedUpToN.toMap, "user") } } @@ -853,7 +853,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { val topItems = model.recommendForUserSubset(userSubset, k) assert(topItems.count() == numUsersSubset) assert(topItems.columns.contains("user")) - checkRecommendations(topItems, expectedUpToN, "item") + checkRecommendations(topItems, expectedUpToN.toMap, "item") } } @@ -875,7 +875,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { val topUsers = model.recommendForItemSubset(itemSubset, k) assert(topUsers.count() == numItemsSubset) assert(topUsers.columns.contains("item")) - checkRecommendations(topUsers, expectedUpToN, "user") + checkRecommendations(topUsers, expectedUpToN.toMap, "user") } } @@ -1211,6 +1211,6 @@ object 
ALSSuite extends Logging { } logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " + s"and ${test.size} for test.") - (sc.parallelize(training, 2), sc.parallelize(test, 2)) + (sc.parallelize(training.toSeq, 2), sc.parallelize(test.toSeq, 2)) } } diff --git a/pom.xml b/pom.xml index 08ca13bfe9d37..cfcb55b27fa99 100644 --- a/pom.xml +++ b/pom.xml @@ -669,7 +669,7 @@ com.github.luben zstd-jni - 1.4.5-2 + 1.4.5-4 com.clearspring.analytics @@ -2594,6 +2594,12 @@ lib_managed + + metastore_db + + + spark-warehouse + @@ -3159,7 +3165,7 @@ scala-2.13 - 2.13.1 + 2.13.3 2.13 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 04a3fc4b63050..d19b514d662fa 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -444,10 +444,10 @@ object SparkBuild extends PomBuild { object SparkParallelTestGrouping { // Settings for parallelizing tests. The basic strategy here is to run the slowest suites (or // collections of suites) in their own forked JVMs, allowing us to gain parallelism within a - // SBT project. Here, we take a whitelisting approach where the default behavior is to run all + // SBT project. Here, we take an opt-in approach where the default behavior is to run all // tests sequentially in a single JVM, requiring us to manually opt-in to the extra parallelism. // - // There are a reasons why such a whitelist approach is good: + // There are a reasons why such an opt-in approach is good: // // 1. Launching one JVM per suite adds significant overhead for short-running suites. In // addition to JVM startup time and JIT warmup, it appears that initialization of Derby @@ -476,6 +476,8 @@ object SparkParallelTestGrouping { "org.apache.spark.ml.classification.LogisticRegressionSuite", "org.apache.spark.ml.classification.LinearSVCSuite", "org.apache.spark.sql.SQLQueryTestSuite", + "org.apache.spark.sql.hive.client.HadoopVersionInfoSuite", + "org.apache.spark.sql.hive.thriftserver.SparkExecuteStatementOperationSuite", "org.apache.spark.sql.hive.thriftserver.ThriftServerQueryTestSuite", "org.apache.spark.sql.hive.thriftserver.SparkSQLEnvSuite", "org.apache.spark.sql.hive.thriftserver.ui.ThriftServerPageSuite", @@ -1014,9 +1016,20 @@ object TestSettings { sys.props.get("test.exclude.tags").map { tags => Seq("--exclude-categories=" + tags) }.getOrElse(Nil): _*), + // Include tags defined in a system property + testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, + sys.props.get("test.include.tags").map { tags => + tags.split(",").flatMap { tag => Seq("-n", tag) }.toSeq + }.getOrElse(Nil): _*), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, + sys.props.get("test.include.tags").map { tags => + Seq("--include-categories=" + tags) + }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + // Required to detect Junit tests for each project, see also https://github.com/sbt/junit-interface/issues/35 + crossPaths := false, // Enable Junit testing. libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", // `parallelExecutionInTest` controls whether test suites belonging to the same SBT project diff --git a/python/pylintrc b/python/pylintrc index 6a675770da69a..26d2741d3b56f 100644 --- a/python/pylintrc +++ b/python/pylintrc @@ -27,7 +27,7 @@ # Profiled execution. profile=no -# Add files or directories to the blacklist. 
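The `test.include.tags` plumbing added to `SparkBuild.scala` above mirrors the existing `test.exclude.tags` handling: included tags become ScalaTest `-n` arguments and JUnit `--include-categories`, which is what lets the workflow split slow Hive/SQL suites from the rest. For context, this is roughly how a tagged test looks on the ScalaTest side; the tag name and suite below are hypothetical, and Spark's real tags (such as `org.apache.spark.tags.ExtendedSQLTest`) are Java annotations applied to whole suites rather than per-test `Tag` objects. Assuming ScalaTest 3.1+:

```scala
import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite

// Hypothetical tag: tests carrying it can be selected with -n <tag name>
// (include) or dropped with -l <tag name> (exclude) at run time.
object SlowDemoTest extends Tag("org.example.tags.SlowDemoTest")

class TagFilteringSketch extends AnyFunSuite {
  test("fast path") {
    assert(1 + 1 == 2)
  }

  test("slow path", SlowDemoTest) {
    assert(Seq.fill(1000)(1).sum == 1000)
  }
}
```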
They should be base names, not +# Add files or directories to the ignoreList. They should be base names, not # paths. ignore=pyspark.heapq3 diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index a5d513262b266..2a19d233bc652 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -89,10 +89,7 @@ import sys import select import struct -if sys.version < '3': - import SocketServer -else: - import socketserver as SocketServer +import socketserver as SocketServer import threading from pyspark.serializers import read_int, PickleSerializer diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 803d857055dc0..c2daf7600ff26 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -20,16 +20,12 @@ import sys from tempfile import NamedTemporaryFile import threading +import pickle from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ChunkedStream, pickle_protocol -from pyspark.util import _exception_message, print_exec +from pyspark.util import print_exec -if sys.version < '3': - import cPickle as pickle -else: - import pickle - unicode = str __all__ = ['Broadcast'] @@ -113,7 +109,7 @@ def dump(self, value, f): raise except Exception as e: msg = "Could not serialize broadcast: %s: %s" \ - % (e.__class__.__name__, _exception_message(e)) + % (e.__class__.__name__, str(e)) print_exec(sys.stderr) raise pickle.PicklingError(msg) f.close() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 09d3a5e7cfb6f..af49c77a2d98c 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -87,8 +87,8 @@ PY2 = True PY2_WRAPPER_DESCRIPTOR_TYPE = type(object.__init__) PY2_METHOD_WRAPPER_TYPE = type(object.__eq__) - PY2_CLASS_DICT_BLACKLIST = (PY2_METHOD_WRAPPER_TYPE, - PY2_WRAPPER_DESCRIPTOR_TYPE) + PY2_CLASS_DICT_SKIP_PICKLE_METHOD_TYPE = (PY2_METHOD_WRAPPER_TYPE, + PY2_WRAPPER_DESCRIPTOR_TYPE) else: types.ClassType = type from pickle import _Pickler as Pickler @@ -327,7 +327,7 @@ def _extract_class_dict(cls): if hasattr(value, "im_func"): if value.im_func is getattr(base_value, "im_func", None): to_remove.append(name) - elif isinstance(value, PY2_CLASS_DICT_BLACKLIST): + elif isinstance(value, PY2_CLASS_DICT_SKIP_PICKLE_METHOD_TYPE): # On Python 2 we have no way to pickle those specific # methods types nor to check that they are actually # inherited. 
So we assume that they are always inherited diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 2024260868197..efd8b6d633e0c 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -22,14 +22,14 @@ >>> conf.setMaster("local").setAppName("My app") >>> conf.get("spark.master") -u'local' +'local' >>> conf.get("spark.app.name") -u'My app' +'My app' >>> sc = SparkContext(conf=conf) >>> sc.master -u'local' +'local' >>> sc.appName -u'My app' +'My app' >>> sc.sparkHome is None True @@ -37,21 +37,21 @@ >>> conf.setSparkHome("/path") >>> conf.get("spark.home") -u'/path' +'/path' >>> conf.setExecutorEnv("VAR1", "value1") >>> conf.setExecutorEnv(pairs = [("VAR3", "value3"), ("VAR4", "value4")]) >>> conf.get("spark.executorEnv.VAR1") -u'value1' +'value1' >>> print(conf.toDebugString()) spark.executorEnv.VAR1=value1 spark.executorEnv.VAR3=value3 spark.executorEnv.VAR4=value4 spark.home=/path >>> sorted(conf.getAll(), key=lambda p: p[0]) -[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \ -(u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] +[('spark.executorEnv.VAR1', 'value1'), ('spark.executorEnv.VAR3', 'value3'), \ +('spark.executorEnv.VAR4', 'value4'), ('spark.home', '/path')] >>> conf._jconf.setExecutorEnv("VAR5", "value5") JavaObject id... >>> print(conf.toDebugString()) @@ -65,11 +65,6 @@ __all__ = ['SparkConf'] import sys -import re - -if sys.version > '3': - unicode = str - __doc__ = re.sub(r"(\W|^)[uU](['])", r'\1\2', __doc__) class SparkConf(object): @@ -124,9 +119,9 @@ def set(self, key, value): """Set a configuration property.""" # Try to set self._jconf first if JVM is created, set self._conf if JVM is not created yet. if self._jconf is not None: - self._jconf.set(key, unicode(value)) + self._jconf.set(key, str(value)) else: - self._conf[key] = unicode(value) + self._conf[key] = str(value) return self def setIfMissing(self, key, value): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 32d69edb171db..2e105cc38260d 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,6 +21,7 @@ import sys import threading import warnings +import importlib from threading import RLock from tempfile import NamedTemporaryFile @@ -37,14 +38,12 @@ PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream from pyspark.storagelevel import StorageLevel from pyspark.resource.information import ResourceInformation -from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix +from pyspark.rdd import RDD, _load_from_socket +from pyspark.taskcontext import TaskContext from pyspark.traceback_utils import CallSite, first_spark_call from pyspark.status import StatusTracker from pyspark.profiler import ProfilerCollector, BasicProfiler -if sys.version > '3': - xrange = range - __all__ = ['SparkContext'] @@ -118,6 +117,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, ... ValueError:... """ + # In order to prevent SparkContext from being created in executors. 
+ SparkContext._assert_on_driver() + self._callsite = first_spark_call() or CallSite(None, None, None) if gateway is not None and gateway.gateway_parameters.auth_token is None: raise ValueError( @@ -209,15 +211,6 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') self.pythonVer = "%d.%d" % sys.version_info[:2] - if sys.version_info < (3, 6): - with warnings.catch_warnings(): - warnings.simplefilter("once") - warnings.warn( - "Support for Python 2 and Python 3 prior to version 3.6 is deprecated as " - "of Spark 3.0. See also the plan for dropping Python 2 support at " - "https://spark.apache.org/news/plan-for-dropping-python-2-support.html.", - DeprecationWarning) - # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to @@ -394,7 +387,6 @@ def version(self): return self._jsc.version() @property - @ignore_unicode_prefix def applicationId(self): """ A unique identifier for the Spark application. @@ -404,7 +396,7 @@ def applicationId(self): * in case of YARN something like 'application_1433865536131_34483' >>> sc.applicationId # doctest: +ELLIPSIS - u'local-...' + 'local-...' """ return self._jsc.sc().applicationId() @@ -486,20 +478,20 @@ def range(self, start, end=None, step=1, numSlices=None): end = start start = 0 - return self.parallelize(xrange(start, end, step), numSlices) + return self.parallelize(range(start, end, step), numSlices) def parallelize(self, c, numSlices=None): """ - Distribute a local Python collection to form an RDD. Using xrange + Distribute a local Python collection to form an RDD. Using range is recommended if the input represents a range for performance. >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect() [[0], [2], [3], [4], [6]] - >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect() + >>> sc.parallelize(range(0, 6, 2), 5).glom().collect() [[], [0], [], [2], [4]] """ numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism - if isinstance(c, xrange): + if isinstance(c, range): size = len(c) if size == 0: return self.parallelize([], numSlices) @@ -518,7 +510,7 @@ def f(split, iterator): # the empty iterator to a list, thus make sure worker reuse takes effect. # See more details in SPARK-26549. assert len(list(iterator)) == 0 - return xrange(getStart(split), getStart(split + 1), step) + return range(getStart(split), getStart(split + 1), step) return self.parallelize([], numSlices).mapPartitionsWithIndex(f) @@ -587,7 +579,6 @@ def pickleFile(self, name, minPartitions=None): minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.objectFile(name, minPartitions), self) - @ignore_unicode_prefix def textFile(self, name, minPartitions=None, use_unicode=True): """ Read a text file from HDFS, a local file system (available on all @@ -604,13 +595,12 @@ def textFile(self, name, minPartitions=None, use_unicode=True): ... 
_ = testFile.write("Hello world!") >>> textFile = sc.textFile(path) >>> textFile.collect() - [u'Hello world!'] + ['Hello world!'] """ minPartitions = minPartitions or min(self.defaultParallelism, 2) return RDD(self._jsc.textFile(name, minPartitions), self, UTF8Deserializer(use_unicode)) - @ignore_unicode_prefix def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): """ Read a directory of text files from HDFS, a local file system @@ -654,7 +644,7 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): ... _ = file2.write("2") >>> textFiles = sc.wholeTextFiles(dirPath) >>> sorted(textFiles.collect()) - [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')] + [('.../1.txt', '1'), ('.../2.txt', '2')] """ minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, @@ -842,7 +832,6 @@ def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) return RDD(jrdd, self, input_deserializer) - @ignore_unicode_prefix def union(self, rdds): """ Build the union of a list of RDDs. @@ -856,10 +845,10 @@ def union(self, rdds): ... _ = testFile.write("Hello") >>> textFile = sc.textFile(path) >>> textFile.collect() - [u'Hello'] + ['Hello'] >>> parallelized = sc.parallelize(["World!"]) >>> sorted(sc.union([textFile, parallelized]).collect()) - [u'Hello', 'World!'] + ['Hello', 'World!'] """ first_jrdd_deserializer = rdds[0]._jrdd_deserializer if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): @@ -955,9 +944,8 @@ def addPyFile(self, path): self._python_includes.append(filename) # for tests in local mode sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) - if sys.version > '3': - import importlib - importlib.invalidate_caches() + + importlib.invalidate_caches() def setCheckpointDir(self, dirName): """ @@ -1145,6 +1133,16 @@ def resources(self): resources[name] = ResourceInformation(name, addrs) return resources + @staticmethod + def _assert_on_driver(): + """ + Called to ensure that SparkContext is created only on the Driver. + + Throws an exception if a SparkContext is about to be created in executors. + """ + if TaskContext.get() is not None: + raise Exception("SparkContext should only be created and accessed on the driver.") + def _test(): import atexit diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py index 52f6ea9a37100..920c04009dd11 100755 --- a/python/pyspark/find_spark_home.py +++ b/python/pyspark/find_spark_home.py @@ -20,7 +20,6 @@ # This script attempt to determine the correct setting for SPARK_HOME given # that Spark may have been installed on the system with pip. 
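The `SparkContext._assert_on_driver()` helper added to context.py above keys off `TaskContext`: on the driver `TaskContext.get()` returns `None`, while inside a running task it returns the current context, so building a `SparkContext` in an executor now fails fast. A rough sketch of the behaviour, assuming a local PySpark build that includes this patch (the `foreach` call is left commented out because it is expected to raise):

```python
from pyspark import SparkContext, TaskContext

sc = SparkContext("local[2]", "driver-only-check")
assert TaskContext.get() is None   # no task context on the driver

def create_context_in_task(_):
    # Runs inside a Python worker, where TaskContext.get() is not None,
    # so the new guard in SparkContext.__init__ raises immediately.
    SparkContext.getOrCreate()

# sc.parallelize([1]).foreach(create_context_in_task)
# -> Exception: SparkContext should only be created and accessed on the driver.
sc.stop()
```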
-from __future__ import print_function import os import sys @@ -41,26 +40,15 @@ def is_spark_home(path): # Add the path of the PySpark module if it exists import_error_raised = False - if sys.version < "3": - import imp - try: - module_home = imp.find_module("pyspark")[1] - paths.append(module_home) - # If we are installed in edit mode also look two dirs up - paths.append(os.path.join(module_home, "../../")) - except ImportError: - # Not pip installed no worries - import_error_raised = True - else: - from importlib.util import find_spec - try: - module_home = os.path.dirname(find_spec("pyspark").origin) - paths.append(module_home) - # If we are installed in edit mode also look two dirs up - paths.append(os.path.join(module_home, "../../")) - except ImportError: - # Not pip installed no worries - import_error_raised = True + from importlib.util import find_spec + try: + module_home = os.path.dirname(find_spec("pyspark").origin) + paths.append(module_home) + # If we are installed in edit mode also look two dirs up + paths.append(os.path.join(module_home, "../../")) + except ImportError: + # Not pip installed no worries + import_error_raised = True # Normalize the paths paths = [os.path.abspath(p) for p in paths] @@ -84,5 +72,6 @@ def is_spark_home(path): "'PYSPARK_PYTHON=python3 pyspark'.\n", file=sys.stderr) sys.exit(-1) + if __name__ == "__main__": print(_find_spark_home()) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0daf09b17a82a..fba92a96ae1a1 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -17,7 +17,6 @@ import atexit import os -import sys import signal import shlex import shutil @@ -27,14 +26,10 @@ import time from subprocess import Popen, PIPE -if sys.version >= '3': - xrange = range - from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from py4j.clientserver import ClientServer, JavaParameters, PythonParameters from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int, write_with_length, UTF8Deserializer -from pyspark.util import _exception_message def launch_gateway(conf=None, popen_kwargs=None): @@ -197,7 +192,7 @@ def local_connect_and_auth(port, auth_secret): _do_server_auth(sockfile, auth_secret) return (sockfile, sock) except socket.error as e: - emsg = _exception_message(e) + emsg = str(e) errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg)) sock.close() sock = None diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d70932a1bc6fc..4f2d33adbc7e7 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -16,20 +16,20 @@ # import operator -import sys +import warnings from abc import ABCMeta, abstractmethod, abstractproperty from multiprocessing.pool import ThreadPool -from pyspark import since, keyword_only +from pyspark import keyword_only from pyspark.ml import Estimator, Predictor, PredictionModel, Model from pyspark.ml.param.shared import * from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \ _TreeEnsembleModel, _RandomForestParams, _GBTParams, \ - _HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams + _HasVarianceImpurity, _TreeClassifierParams from pyspark.ml.regression import _FactorizationMachinesParams, DecisionTreeRegressionModel from pyspark.ml.util import * from pyspark.ml.base import _PredictorParams -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \ +from 
pyspark.ml.wrapper import JavaParams, \ JavaPredictor, JavaPredictionModel, JavaWrapper from pyspark.ml.common import inherit_doc, _java2py, _py2java from pyspark.ml.linalg import Vectors @@ -52,7 +52,8 @@ 'NaiveBayes', 'NaiveBayesModel', 'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel', 'OneVsRest', 'OneVsRestModel', - 'FMClassifier', 'FMClassificationModel'] + 'FMClassifier', 'FMClassificationModel', 'FMClassificationSummary', + 'FMClassificationTrainingSummary'] class _ClassifierParams(HasRawPredictionCol, _PredictorParams): @@ -2421,6 +2422,10 @@ class _MultilayerPerceptronParams(_ProbabilisticClassifierParams, HasSeed, HasMa initialWeights = Param(Params._dummy(), "initialWeights", "The initial weights of the model.", typeConverter=TypeConverters.toVector) + def __init__(self): + super(_MultilayerPerceptronParams, self).__init__() + self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs") + @since("1.6.0") def getLayers(self): """ @@ -2524,7 +2529,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) - self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -3120,9 +3124,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(FMClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.FMClassifier", self.uid) - self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, - miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, - tol=1e-6, solver="adamW") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -3226,7 +3227,7 @@ def setRegParam(self, value): class FMClassificationModel(_JavaProbabilisticClassificationModel, _FactorizationMachinesParams, - JavaMLWritable, JavaMLReadable): + JavaMLWritable, JavaMLReadable, HasTrainingSummary): """ Model fitted by :class:`FMClassifier`. @@ -3257,6 +3258,49 @@ def factors(self): """ return self._call_java("factors") + @since("3.1.0") + def summary(self): + """ + Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. + """ + if self.hasSummary: + return FMClassificationTrainingSummary(super(FMClassificationModel, self).summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + + @since("3.1.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_fm_summary = self._call_java("evaluate", dataset) + return FMClassificationSummary(java_fm_summary) + + +class FMClassificationSummary(_BinaryClassificationSummary): + """ + Abstraction for FMClassifier Results for a given model. + .. versionadded:: 3.1.0 + """ + pass + + +@inherit_doc +class FMClassificationTrainingSummary(FMClassificationSummary, _TrainingSummary): + """ + Abstraction for FMClassifier Training results. + .. 
versionadded:: 3.1.0 + """ + pass + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py index 387c5d7309dea..4e1d7f93aef9b 100644 --- a/python/pyspark/ml/common.py +++ b/python/pyspark/ml/common.py @@ -15,11 +15,6 @@ # limitations under the License. # -import sys -if sys.version >= '3': - long = int - unicode = str - import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject @@ -79,7 +74,7 @@ def _py2java(sc, obj): obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): pass - elif isinstance(obj, (int, long, float, bool, bytes, unicode)): + elif isinstance(obj, (int, float, bool, bytes, str)): pass else: data = bytearray(PickleSerializer().dumps(obj)) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 498629cea846c..a319dace6869a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,12 +15,7 @@ # limitations under the License. # -import sys -if sys.version > '3': - basestring = str - from pyspark import since, keyword_only, SparkContext -from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.linalg import _convert_to_vector from pyspark.ml.param.shared import * from pyspark.ml.util import JavaMLReadable, JavaMLWritable @@ -2178,7 +2173,6 @@ def originalMax(self): @inherit_doc -@ignore_unicode_prefix class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ A feature transformer that converts the input array of strings into an array of n-grams. Null @@ -2196,15 +2190,15 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr >>> ngram.setOutputCol("nGrams") NGram... >>> ngram.transform(df).head() - Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e']) + Row(inputTokens=['a', 'b', 'c', 'd', 'e'], nGrams=['a b', 'b c', 'c d', 'd e']) >>> # Change n-gram length >>> ngram.setParams(n=4).transform(df).head() - Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + Row(inputTokens=['a', 'b', 'c', 'd', 'e'], nGrams=['a b c d', 'b c d e']) >>> # Temporarily modify output column. >>> ngram.transform(df, {ngram.outputCol: "output"}).head() - Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e']) + Row(inputTokens=['a', 'b', 'c', 'd', 'e'], output=['a b c d', 'b c d e']) >>> ngram.transform(df).head() - Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + Row(inputTokens=['a', 'b', 'c', 'd', 'e'], nGrams=['a b c d', 'b c d e']) >>> # Must use keyword arguments to specify params. >>> ngram.setParams("text") Traceback (most recent call last): @@ -3082,7 +3076,6 @@ def range(self): @inherit_doc -@ignore_unicode_prefix class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ A regex based tokenizer that extracts tokens either by using the @@ -3099,15 +3092,15 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, >>> reTokenizer.setOutputCol("words") RegexTokenizer... >>> reTokenizer.transform(df).head() - Row(text=u'A B c', words=[u'a', u'b', u'c']) + Row(text='A B c', words=['a', 'b', 'c']) >>> # Change a parameter. >>> reTokenizer.setParams(outputCol="tokens").transform(df).head() - Row(text=u'A B c', tokens=[u'a', u'b', u'c']) + Row(text='A B c', tokens=['a', 'b', 'c']) >>> # Temporarily modify a parameter. 
>>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head() - Row(text=u'A B c', words=[u'a', u'b', u'c']) + Row(text='A B c', words=['a', 'b', 'c']) >>> reTokenizer.transform(df).head() - Row(text=u'A B c', tokens=[u'a', u'b', u'c']) + Row(text='A B c', tokens=['a', 'b', 'c']) >>> # Must use keyword arguments to specify params. >>> reTokenizer.setParams("text") Traceback (most recent call last): @@ -3935,7 +3928,6 @@ def loadDefaultStopWords(language): @inherit_doc -@ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ A tokenizer that converts the input string to lowercase and then @@ -3946,15 +3938,15 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java >>> tokenizer.setInputCol("text") Tokenizer... >>> tokenizer.transform(df).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text='a b c', words=['a', 'b', 'c']) >>> # Change a parameter. >>> tokenizer.setParams(outputCol="tokens").transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text='a b c', tokens=['a', 'b', 'c']) >>> # Temporarily modify a parameter. >>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text='a b c', words=['a', 'b', 'c']) >>> tokenizer.transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text='a b c', tokens=['a', 'b', 'c']) >>> # Must use keyword arguments to specify params. >>> tokenizer.setParams("text") Traceback (most recent call last): @@ -4476,7 +4468,6 @@ def getMaxSentenceLength(self): @inherit_doc -@ignore_unicode_prefix class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable): """ Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further @@ -4505,7 +4496,7 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable): +----+--------------------+ ... >>> model.findSynonymsArray("a", 2) - [(u'b', 0.015859870240092278), (u'c', -0.5680795907974243)] + [('b', 0.015859870240092278), ('c', -0.5680795907974243)] >>> from pyspark.sql.functions import format_number as fmt >>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show() +----+----------+ @@ -4668,7 +4659,7 @@ def findSynonyms(self, word, num): Returns a dataframe with two fields word and similarity (which gives the cosine similarity). """ - if not isinstance(word, basestring): + if not isinstance(word, str): word = _convert_to_vector(word) return self._call_java("findSynonyms", word, num) @@ -4680,7 +4671,7 @@ def findSynonymsArray(self, word, num): Returns an array with two fields word and similarity (which gives the cosine similarity). """ - if not isinstance(word, basestring): + if not isinstance(word, str): word = _convert_to_vector(word) tuples = self._java_obj.findSynonymsArray(word, num) return list(map(lambda st: (st._1(), st._2()), list(tuples))) @@ -5745,6 +5736,7 @@ def selectedFeatures(self): if __name__ == "__main__": import doctest + import sys import tempfile import pyspark.ml.feature diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index 7a5591f3fbf76..b91788a82c19a 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -15,8 +15,7 @@ # limitations under the License. 
# -from pyspark import keyword_only, since -from pyspark.rdd import ignore_unicode_prefix +from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams @@ -132,7 +131,6 @@ def associationRules(self): return self._call_java("associationRules") -@ignore_unicode_prefix class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable): r""" A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in @@ -193,7 +191,7 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable): ... >>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"]) >>> sorted(fpm.transform(new_data).first().newPrediction) - [u'x', u'y', u'z'] + ['x', 'y', 'z'] .. versionadded:: 2.2.0 """ diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 4fb1036fbab89..20b24559b182d 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -25,14 +25,13 @@ """ import sys -import warnings import numpy as np from distutils.version import LooseVersion from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import SparkSession __all__ = ["ImageSchema"] diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index a79d5e5dcbb16..8be440da4fef8 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -27,18 +27,8 @@ import array import struct -if sys.version >= '3': - basestring = str - xrange = range - import copyreg as copy_reg - long = int -else: - from itertools import izip as zip - import copy_reg - import numpy as np -from pyspark import since from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ IntegerType, ByteType, BooleanType @@ -47,13 +37,6 @@ 'Matrix', 'DenseMatrix', 'SparseMatrix', 'Matrices'] -if sys.version_info[:2] == (2, 7): - # speed up pickling array in Python 2.7 - def fast_pickle_array(ar): - return array.array, (ar.typecode, ar.tostring()) - copy_reg.pickle(array.array, fast_pickle_array) - - # Check whether we have SciPy. MLlib works without it too, but if we have it, some methods, # such as _dot and _serialize_double_vector, start to support scipy.sparse matrices. 
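With the Python 2 alias block removed from pyspark.ml.linalg above (`xrange`, `long`, `basestring` and the `copy_reg` pickling shim), the converters in the following hunks simply test against the Python 3 builtins. A small sketch of what that buys, assuming a PySpark install carrying this patch:

```python
from pyspark.ml.linalg import DenseVector, Vectors

# range objects and plain ints are accepted directly; no xrange/long aliases needed.
v = Vectors.dense(range(3))
assert isinstance(v, DenseVector)
assert v == DenseVector([0.0, 1.0, 2.0])
```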
@@ -68,7 +51,7 @@ def fast_pickle_array(ar): def _convert_to_vector(l): if isinstance(l, Vector): return l - elif type(l) in (array.array, np.array, np.ndarray, list, tuple, xrange): + elif type(l) in (array.array, np.array, np.ndarray, list, tuple, range): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" @@ -102,7 +85,7 @@ def _vector_size(v): """ if isinstance(v, Vector): return len(v) - elif type(v) in (array.array, list, tuple, xrange): + elif type(v) in (array.array, list, tuple, range): return len(v) elif type(v) == np.ndarray: if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1): @@ -415,7 +398,7 @@ def __eq__(self, other): elif isinstance(other, SparseVector): if len(self) != other.size: return False - return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values) + return Vectors._equals(list(range(len(self))), self.array, other.indices, other.values) return False def __ne__(self, other): @@ -520,7 +503,7 @@ def __init__(self, size, *args): self.indices = np.array(args[0], dtype=np.int32) self.values = np.array(args[1], dtype=np.float64) assert len(self.indices) == len(self.values), "index and value arrays not same length" - for i in xrange(len(self.indices) - 1): + for i in range(len(self.indices) - 1): if self.indices[i] >= self.indices[i + 1]: raise TypeError( "Indices %s and %s are not strictly increasing" @@ -699,7 +682,7 @@ def __repr__(self): inds = self.indices vals = self.values entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i])) - for i in xrange(len(inds))]) + for i in range(len(inds))]) return "SparseVector({0}, {{{1}}})".format(self.size, entries) def __eq__(self, other): @@ -709,7 +692,7 @@ def __eq__(self, other): elif isinstance(other, DenseVector): if self.size != len(other): return False - return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array) + return Vectors._equals(self.indices, self.values, list(range(len(other))), other.array) return False def __getitem__(self, index): @@ -791,7 +774,7 @@ def dense(*elements): >>> Vectors.dense(1.0, 2.0) DenseVector([1.0, 2.0]) """ - if len(elements) == 1 and not isinstance(elements[0], (float, int, long)): + if len(elements) == 1 and not isinstance(elements[0], (float, int)): # it's list, numpy.array or other iterable object. 
elements = elements[0] return DenseVector(elements) @@ -1124,7 +1107,7 @@ def toArray(self): Return a numpy.ndarray """ A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F') - for k in xrange(self.colPtrs.size - 1): + for k in range(self.colPtrs.size - 1): startptr = self.colPtrs[k] endptr = self.colPtrs[k + 1] if self.isTransposed: diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 1be8755c7b982..96b07bfa5f14f 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -16,15 +16,10 @@ # import array import sys -if sys.version > '3': - basestring = str - xrange = range - unicode = str - from abc import ABCMeta import copy -import numpy as np +import numpy as np from py4j.java_gateway import JavaObject from pyspark.ml.linalg import DenseVector, Vector, Matrix @@ -93,12 +88,12 @@ def _is_integer(value): @staticmethod def _can_convert_to_list(value): vtype = type(value) - return vtype in [list, np.ndarray, tuple, xrange, array.array] or isinstance(value, Vector) + return vtype in [list, np.ndarray, tuple, range, array.array] or isinstance(value, Vector) @staticmethod def _can_convert_to_string(value): vtype = type(value) - return isinstance(value, basestring) or vtype in [np.unicode_, np.string_, np.str_] + return isinstance(value, str) or vtype in [np.unicode_, np.string_, np.str_] @staticmethod def identity(value): @@ -114,7 +109,7 @@ def toList(value): """ if type(value) == list: return value - elif type(value) in [np.ndarray, tuple, xrange, array.array]: + elif type(value) in [np.ndarray, tuple, range, array.array]: return list(value) elif isinstance(value, Vector): return list(value.toArray()) @@ -211,12 +206,10 @@ def toString(value): """ Convert a value to a string, if possible. """ - if isinstance(value, basestring): + if isinstance(value, str): return value - elif type(value) in [np.string_, np.str_]: + elif type(value) in [np.string_, np.str_, np.unicode_]: return str(value) - elif type(value) == np.unicode_: - return unicode(value) else: raise TypeError("Could not convert %s to string type" % type(value)) @@ -338,7 +331,7 @@ def hasParam(self, paramName): Tests whether this instance contains a param with a given (string) name. """ - if isinstance(paramName, basestring): + if isinstance(paramName, str): p = getattr(self, paramName, None) return isinstance(p, Param) else: @@ -421,7 +414,7 @@ def _resolveParam(self, param): if isinstance(param, Param): self._shouldOwn(param) return param - elif isinstance(param, basestring): + elif isinstance(param, str): return self.getParam(param) else: raise ValueError("Cannot resolve %r as a param." % param) @@ -510,7 +503,7 @@ def _resetUid(self, newUid): :return: same instance, but with the uid and Param.parent values updated, including within param maps """ - newUid = unicode(newUid) + newUid = str(newUid) self.uid = newUid newDefaultParamMap = dict() newParamMap = dict() diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 2086e831f4282..bc1ea87ad629c 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - header = """# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. 
See the NOTICE file distributed with diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 53d07ec9660d9..eacb8b82b5244 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -16,12 +16,8 @@ # import sys -import os -if sys.version > '3': - basestring = str - -from pyspark import since, keyword_only, SparkContext +from pyspark import keyword_only from pyspark.ml.base import Estimator, Model, Transformer from pyspark.ml.param import Param, Params from pyspark.ml.util import * diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index b58255ea12afc..e82a35c8e78f1 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1891,6 +1891,11 @@ class _GeneralizedLinearRegressionParams(_PredictorParams, HasFitIntercept, HasM "or empty, we treat all instance offsets as 0.0", typeConverter=TypeConverters.toString) + def __init__(self): + super(_GeneralizedLinearRegressionParams, self).__init__() + self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls", + variancePower=0.0, aggregationDepth=2) + @since("2.0.0") def getFamily(self): """ @@ -2023,8 +2028,6 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) - self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls", - variancePower=0.0, aggregationDepth=2) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -2398,6 +2401,12 @@ class _FactorizationMachinesParams(_PredictorParams, HasMaxIter, HasStepSize, Ha solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + "options: gd, adamW. 
(Default adamW)", typeConverter=TypeConverters.toString) + def __init__(self): + super(_FactorizationMachinesParams, self).__init__() + self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, + miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, + tol=1e-6, solver="adamW") + @since("3.0.0") def getFactorSize(self): """ @@ -2489,9 +2498,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(FMRegressor, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.FMRegressor", self.uid) - self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, - miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, - tol=1e-6, solver="adamW") kwargs = self._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index 4c6bfa696b110..7856a317c261d 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -19,9 +19,6 @@ import sys import unittest -if sys.version > '3': - basestring = str - from pyspark.ml.feature import Binarizer, CountVectorizer, CountVectorizerModel, HashingTF, IDF, \ NGram, RFormula, StopWordsRemover, StringIndexer, StringIndexerModel, VectorSizeHint from pyspark.ml.linalg import DenseVector, SparseVector, Vectors @@ -91,7 +88,7 @@ def test_stopwordsremover(self): transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, ["panda"]) self.assertEqual(type(stopWordRemover.getStopWords()), list) - self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring)) + self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], str)) # Custom stopwords = ["panda"] stopWordRemover.setStopWords(stopwords) diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py index 1b2b1914cc036..e1abd59a2d7b2 100644 --- a/python/pyspark/ml/tests/test_param.py +++ b/python/pyspark/ml/tests/test_param.py @@ -35,10 +35,6 @@ from pyspark.testing.mlutils import check_params, PySparkTestCase, SparkSessionTestCase -if sys.version > '3': - xrange = range - - class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. @@ -67,14 +63,14 @@ def test_vector(self): def test_list(self): l = [0, 1] for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), range(len(l)), l), - pyarray.array('l', l), xrange(2), tuple(l)]: + pyarray.array('l', l), range(2), tuple(l)]: converted = TypeConverters.toList(lst_like) self.assertEqual(type(converted), list) self.assertListEqual(converted, l) def test_list_int(self): for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), - SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), + SparseVector(2, {0: 1.0, 1: 2.0}), range(1, 3), (1.0, 2.0), pyarray.array('d', [1.0, 2.0])]: vs = VectorSlicer(indices=indices) self.assertListEqual(vs.getIndices(), [1, 2]) @@ -200,12 +196,7 @@ def test_resolveparam(self): self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) - if sys.version_info[0] >= 3: - # In Python 3, it is allowed to get/set attributes with non-ascii characters. 
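The regression.py and classification.py hunks above move the `_setDefault(...)` calls out of the estimator constructors and into the shared `_...Params.__init__`, so a model rebuilt by the persistence reader (which never goes through the estimator) reports the same defaults; the `TestDefaultSolver` cases added below check exactly that for MLP, FM and GLR. A stripped-down sketch of the pattern, using hypothetical `_FooParams`/`FooModel` classes that are not part of Spark:

```python
from pyspark.ml.param import Param, Params, TypeConverters

class _FooParams(Params):
    solver = Param(Params._dummy(), "solver", "solver name",
                   typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_FooParams, self).__init__()
        # The default lives in the shared mixin, not in an estimator's constructor.
        self._setDefault(solver="adamW")

class FooModel(_FooParams):
    """Stand-in for a model instantiated by a reader rather than by an estimator."""
    pass

m = FooModel()
assert m.getOrDefault(m.solver) == "adamW"
```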
- e_cls = AttributeError - else: - e_cls = UnicodeEncodeError - self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) + self.assertRaises(AttributeError, lambda: testParams._resolveParam(u"아")) def test_params(self): testParams = TestParams() diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py index d4edcc26e17ac..2f6d451851b4b 100644 --- a/python/pyspark/ml/tests/test_persistence.py +++ b/python/pyspark/ml/tests/test_persistence.py @@ -21,19 +21,78 @@ import unittest from pyspark.ml import Transformer -from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \ - OneVsRestModel +from pyspark.ml.classification import DecisionTreeClassifier, FMClassifier, \ + FMClassificationModel, LogisticRegression, MultilayerPerceptronClassifier, \ + MultilayerPerceptronClassificationModel, OneVsRest, OneVsRestModel from pyspark.ml.clustering import KMeans from pyspark.ml.feature import Binarizer, HashingTF, PCA from pyspark.ml.linalg import Vectors from pyspark.ml.param import Params from pyspark.ml.pipeline import Pipeline, PipelineModel -from pyspark.ml.regression import DecisionTreeRegressor, LinearRegression +from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \ + GeneralizedLinearRegressionModel, \ + LinearRegression from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWriter from pyspark.ml.wrapper import JavaParams from pyspark.testing.mlutils import MockUnaryTransformer, SparkSessionTestCase +class TestDefaultSolver(SparkSessionTestCase): + + def test_multilayer_load(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense([0.0, 0.0])), + (1.0, Vectors.dense([0.0, 1.0])), + (1.0, Vectors.dense([1.0, 0.0])), + (0.0, Vectors.dense([1.0, 1.0]))], + ["label", "features"]) + + mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123) + model = mlp.fit(df) + self.assertEqual(model.getSolver(), "l-bfgs") + transformed1 = model.transform(df) + path = tempfile.mkdtemp() + model_path = path + "/mlp" + model.save(model_path) + model2 = MultilayerPerceptronClassificationModel.load(model_path) + self.assertEqual(model2.getSolver(), "l-bfgs") + transformed2 = model2.transform(df) + self.assertEqual(transformed1.take(4), transformed2.take(4)) + + def test_fm_load(self): + df = self.spark.createDataFrame([(1.0, Vectors.dense(1.0)), + (0.0, Vectors.sparse(1, [], []))], + ["label", "features"]) + fm = FMClassifier(factorSize=2, maxIter=50, stepSize=2.0) + model = fm.fit(df) + self.assertEqual(model.getSolver(), "adamW") + transformed1 = model.transform(df) + path = tempfile.mkdtemp() + model_path = path + "/fm" + model.save(model_path) + model2 = FMClassificationModel.load(model_path) + self.assertEqual(model2.getSolver(), "adamW") + transformed2 = model2.transform(df) + self.assertEqual(transformed1.take(2), transformed2.take(2)) + + def test_glr_load(self): + df = self.spark.createDataFrame([(1.0, Vectors.dense(0.0, 0.0)), + (1.0, Vectors.dense(1.0, 2.0)), + (2.0, Vectors.dense(0.0, 0.0)), + (2.0, Vectors.dense(1.0, 1.0))], + ["label", "features"]) + glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p") + model = glr.fit(df) + self.assertEqual(model.getSolver(), "irls") + transformed1 = model.transform(df) + path = tempfile.mkdtemp() + model_path = path + "/glr" + model.save(model_path) + model2 = GeneralizedLinearRegressionModel.load(model_path) + self.assertEqual(model2.getSolver(), "irls") + transformed2 = 
model2.transform(df) + self.assertEqual(transformed1.take(4), transformed2.take(4)) + + class PersistenceTest(SparkSessionTestCase): def test_linear_regression(self): diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index 7d905793188bf..d305be8b96cd4 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -18,11 +18,9 @@ import sys import unittest -if sys.version > '3': - basestring = str - -from pyspark.ml.classification import BinaryLogisticRegressionSummary, LinearSVC, \ - LinearSVCSummary, BinaryRandomForestClassificationSummary, LogisticRegression, \ +from pyspark.ml.classification import BinaryLogisticRegressionSummary, FMClassifier, \ + FMClassificationSummary, LinearSVC, LinearSVCSummary, \ + BinaryRandomForestClassificationSummary, LogisticRegression, \ LogisticRegressionSummary, RandomForestClassificationSummary, \ RandomForestClassifier from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans @@ -101,7 +99,7 @@ def test_glr_summary(self): self.assertEqual(s.residualDegreeOfFreedom, 1) self.assertEqual(s.residualDegreeOfFreedomNull, 2) self.assertEqual(s.rank, 1) - self.assertTrue(isinstance(s.solver, basestring)) + self.assertTrue(isinstance(s.solver, str)) self.assertTrue(isinstance(s.aic, float)) self.assertTrue(isinstance(s.deviance, float)) self.assertTrue(isinstance(s.nullDeviance, float)) @@ -312,6 +310,50 @@ def test_multiclass_randomforest_classification_summary(self): self.assertFalse(isinstance(sameSummary, BinaryRandomForestClassificationSummary)) self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) + def test_fm_classification_summary(self): + df = self.spark.createDataFrame([(1.0, Vectors.dense(2.0)), + (0.0, Vectors.dense(2.0)), + (0.0, Vectors.dense(6.0)), + (1.0, Vectors.dense(3.0)) + ], + ["label", "features"]) + fm = FMClassifier(maxIter=5) + model = fm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary() + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.scoreCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 0.625, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) + self.assertAlmostEqual(s.weightedRecall, 0.75, 2) + self.assertAlmostEqual(s.weightedPrecision, 0.8333333333333333, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 0.7333333333333334, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.7333333333333334, 
2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, FMClassificationSummary)) + self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + def test_gaussian_mixture_summary(self): data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), (Vectors.sparse(1, [], []),)] diff --git a/python/pyspark/ml/tree.py b/python/pyspark/ml/tree.py index a13b27ec8a79c..460c76fabc375 100644 --- a/python/pyspark/ml/tree.py +++ b/python/pyspark/ml/tree.py @@ -15,12 +15,10 @@ # limitations under the License. # -from pyspark import since, keyword_only from pyspark.ml.param.shared import * from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \ - JavaPredictor, JavaPredictionModel -from pyspark.ml.common import inherit_doc, _java2py, _py2java +from pyspark.ml.wrapper import JavaPredictionModel +from pyspark.ml.common import inherit_doc @inherit_doc diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index e00753b2ffc20..7f3d942e2e456 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,12 +15,11 @@ # limitations under the License. # import itertools -import sys from multiprocessing.pool import ThreadPool import numpy as np -from pyspark import since, keyword_only +from pyspark import keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java, _java2py from pyspark.ml.param import Params, Param, TypeConverters diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index aac2b38d3f57d..9ab6bfa9ba968 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -20,12 +20,6 @@ import os import time import uuid -import warnings - -if sys.version > '3': - basestring = str - unicode = str - long = int from pyspark import SparkContext, since from pyspark.ml.common import inherit_doc @@ -60,10 +54,10 @@ def __repr__(self): @classmethod def _randomUID(cls): """ - Generate a unique unicode id for the object. The default implementation + Generate a unique string id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. 
""" - return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:]) + return str(cls.__name__ + "_" + uuid.uuid4().hex[-12:]) @inherit_doc @@ -170,8 +164,8 @@ def __init__(self, instance): def save(self, path): """Save the ML instance to the input path.""" - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) + if not isinstance(path, str): + raise TypeError("path should be a string, got type %s" % type(path)) self._jwrite.save(path) def overwrite(self): @@ -275,8 +269,8 @@ def __init__(self, clazz): def load(self, path): """Load the ML instance from the input path.""" - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) + if not isinstance(path, str): + raise TypeError("path should be a string, got type %s" % type(path)) java_obj = self._jread.load(path) if not hasattr(self._clazz, "_from_java"): raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r" @@ -430,7 +424,7 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): for p in instance._defaultParamMap: jsonDefaultParams[p.name] = instance._defaultParamMap[p] - basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)), + basicMetadata = {"class": cls, "timestamp": int(round(time.time() * 1000)), "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams, "defaultParamMap": jsonDefaultParams} if extraMetadata is not None: diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index e59c6c7b250a8..c1d060a51cf9d 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -16,9 +16,6 @@ # from abc import ABCMeta, abstractmethod -import sys -if sys.version >= '3': - xrange = range from pyspark import since from pyspark import SparkContext @@ -26,7 +23,6 @@ from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, Model from pyspark.ml.base import _PredictorParams from pyspark.ml.param import Params -from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol from pyspark.ml.util import _jvm from pyspark.ml.common import inherit_doc, _java2py, _py2java @@ -99,15 +95,15 @@ def _new_java_array(pylist, java_class): # If pylist is a 2D array, then a 2D java array will be created. # The 2D array is a square, non-jagged 2D array that is big enough for all elements. inner_array_length = 0 - for i in xrange(len(pylist)): + for i in range(len(pylist)): inner_array_length = max(inner_array_length, len(pylist[i])) java_array = sc._gateway.new_array(java_class, len(pylist), inner_array_length) - for i in xrange(len(pylist)): - for j in xrange(len(pylist[i])): + for i in range(len(pylist)): + for j in range(len(pylist[i])): java_array[i][j] = pylist[i][j] else: java_array = sc._gateway.new_array(java_class, len(pylist)) - for i in xrange(len(pylist)): + for i in range(len(pylist)): java_array[i] = pylist[i] return java_array diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index ae26521ea96bf..6067693111547 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -21,8 +21,6 @@ The `pyspark.mllib` package is in maintenance mode as of the Spark 2.0.0 release to encourage migration to the DataFrame-based APIs under the `pyspark.ml` package. 
""" -from __future__ import absolute_import - # MLlib currently needs NumPy 1.4+, so complain if lower import numpy diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e41e5c9cc8e89..85cfe583fd5c5 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -17,20 +17,13 @@ import sys import array as pyarray -import warnings - -if sys.version > '3': - xrange = range - basestring = str - from math import exp, log +from collections import namedtuple from numpy import array, random, tile -from collections import namedtuple - from pyspark import SparkContext, since -from pyspark.rdd import RDD, ignore_unicode_prefix +from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector from pyspark.mllib.stat.distribution import MultivariateGaussian @@ -257,7 +250,7 @@ def predict(self, x): return x.map(self.predict) x = _convert_to_vector(x) - for i in xrange(len(self.centers)): + for i in range(len(self.centers)): distance = x.squared_distance(self.centers[i]) if distance < best_distance: best = i @@ -708,7 +701,7 @@ class StreamingKMeansModel(KMeansModel): >>> stkm = StreamingKMeansModel(initCenters, initWeights) >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1], ... [0.9, 0.9], [1.1, 1.1]]) - >>> stkm = stkm.update(data, 1.0, u"batches") + >>> stkm = stkm.update(data, 1.0, "batches") >>> stkm.centers array([[ 0., 0.], [ 1., 1.]]) @@ -720,7 +713,7 @@ class StreamingKMeansModel(KMeansModel): [3.0, 3.0] >>> decayFactor = 0.0 >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])]) - >>> stkm = stkm.update(data, 0.0, u"batches") + >>> stkm = stkm.update(data, 0.0, "batches") >>> stkm.centers array([[ 0.2, 0.2], [ 1.5, 1.5]]) @@ -743,7 +736,6 @@ def clusterWeights(self): """Return the cluster weights.""" return self._clusterWeights - @ignore_unicode_prefix @since('1.5.0') def update(self, data, decayFactor, timeUnit): """Update the centroids, according to data @@ -979,8 +971,8 @@ def load(cls, sc, path): """ if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) + if not isinstance(path, str): + raise TypeError("path should be a string, got type %s" % type(path)) model = callMLlibFunc("loadLDAModel", sc, path) return LDAModel(model) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index bac8f350563ec..24e2f198251ad 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -15,11 +15,6 @@ # limitations under the License. # -import sys -if sys.version >= '3': - long = int - unicode = str - import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject @@ -81,7 +76,7 @@ def _py2java(sc, obj): obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): pass - elif isinstance(obj, (int, long, float, bool, bytes, unicode)): + elif isinstance(obj, (int, float, bool, bytes, str)): pass else: data = bytearray(PickleSerializer().dumps(obj)) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 3efae6ff0ecc3..80a197eaa7494 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -18,21 +18,15 @@ """ Python package for feature in MLlib. 
""" -from __future__ import absolute_import - import sys import warnings -if sys.version >= '3': - basestring = str - unicode = str - from py4j.protocol import Py4JJavaError from pyspark import since -from pyspark.rdd import RDD, ignore_unicode_prefix +from pyspark.rdd import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( - Vector, Vectors, DenseVector, SparseVector, _convert_to_vector) + Vectors, DenseVector, SparseVector, _convert_to_vector) from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.util import JavaLoader, JavaSaveable @@ -616,7 +610,7 @@ def findSynonyms(self, word, num): .. note:: Local use only """ - if not isinstance(word, basestring): + if not isinstance(word, str): word = _convert_to_vector(word) words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) @@ -640,7 +634,6 @@ def load(cls, sc, path): return Word2VecModel(model) -@ignore_unicode_prefix class Word2Vec(object): """Word2Vec creates vector representation of words in a text corpus. The algorithm first constructs a vocabulary from the corpus @@ -668,7 +661,7 @@ class Word2Vec(object): >>> syms = model.findSynonyms("a", 2) >>> [s[0] for s in syms] - [u'b', u'c'] + ['b', 'c'] But querying for synonyms of a vector may return the word whose representation is that vector: @@ -676,7 +669,7 @@ class Word2Vec(object): >>> vec = model.transform("a") >>> syms = model.findSynonyms(vec, 2) >>> [s[0] for s in syms] - [u'a', u'b'] + ['a', 'b'] >>> import os, tempfile >>> path = tempfile.mkdtemp() @@ -686,7 +679,7 @@ class Word2Vec(object): True >>> syms = sameModel.findSynonyms("a", 2) >>> [s[0] for s in syms] - [u'b', u'c'] + ['b', 'c'] >>> from shutil import rmtree >>> try: ... rmtree(path) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 373a141456b2f..cbbd7b351b20d 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -20,7 +20,6 @@ from collections import namedtuple from pyspark import since -from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc @@ -28,7 +27,6 @@ @inherit_doc -@ignore_unicode_prefix class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ A FP-Growth model for mining frequent itemsets @@ -38,7 +36,7 @@ class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> rdd = sc.parallelize(data, 2) >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) - [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + [FreqItemset(items=['a'], freq=4), FreqItemset(items=['c'], freq=3), ... >>> model_path = temp_path + "/fpm" >>> model.save(sc, model_path) >>> sameModel = FPGrowthModel.load(sc, model_path) @@ -101,7 +99,6 @@ class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): @inherit_doc -@ignore_unicode_prefix class PrefixSpanModel(JavaModelWrapper): """ Model fitted by PrefixSpan @@ -114,7 +111,7 @@ class PrefixSpanModel(JavaModelWrapper): >>> rdd = sc.parallelize(data, 2) >>> model = PrefixSpan.train(rdd) >>> sorted(model.freqSequences().collect()) - [FreqSequence(sequence=[[u'a']], freq=3), FreqSequence(sequence=[[u'a'], [u'a']], freq=1), ... + [FreqSequence(sequence=[['a']], freq=3), FreqSequence(sequence=[['a'], ['a']], freq=1), ... .. 
versionadded:: 1.6.0 """ diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index cd09621b13b56..c1402fb98a50d 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -27,15 +27,6 @@ import array import struct -if sys.version >= '3': - basestring = str - xrange = range - import copyreg as copy_reg - long = int -else: - from itertools import izip as zip - import copy_reg - import numpy as np from pyspark import since @@ -49,13 +40,6 @@ 'QRDecomposition'] -if sys.version_info[:2] == (2, 7): - # speed up pickling array in Python 2.7 - def fast_pickle_array(ar): - return array.array, (ar.typecode, ar.tostring()) - copy_reg.pickle(array.array, fast_pickle_array) - - # Check whether we have SciPy. MLlib works without it too, but if we have it, some methods, # such as _dot and _serialize_double_vector, start to support scipy.sparse matrices. @@ -70,7 +54,7 @@ def fast_pickle_array(ar): def _convert_to_vector(l): if isinstance(l, Vector): return l - elif type(l) in (array.array, np.array, np.ndarray, list, tuple, xrange): + elif type(l) in (array.array, np.array, np.ndarray, list, tuple, range): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" @@ -104,7 +88,7 @@ def _vector_size(v): """ if isinstance(v, Vector): return len(v) - elif type(v) in (array.array, list, tuple, xrange): + elif type(v) in (array.array, list, tuple, range): return len(v) elif type(v) == np.ndarray: if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1): @@ -459,7 +443,7 @@ def __eq__(self, other): elif isinstance(other, SparseVector): if len(self) != other.size: return False - return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values) + return Vectors._equals(list(range(len(self))), self.array, other.indices, other.values) return False def __ne__(self, other): @@ -556,7 +540,7 @@ def __init__(self, size, *args): self.indices = np.array(args[0], dtype=np.int32) self.values = np.array(args[1], dtype=np.float64) assert len(self.indices) == len(self.values), "index and value arrays not same length" - for i in xrange(len(self.indices) - 1): + for i in range(len(self.indices) - 1): if self.indices[i] >= self.indices[i + 1]: raise TypeError( "Indices %s and %s are not strictly increasing" @@ -788,7 +772,7 @@ def __repr__(self): inds = self.indices vals = self.values entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i])) - for i in xrange(len(inds))]) + for i in range(len(inds))]) return "SparseVector({0}, {{{1}}})".format(self.size, entries) def __eq__(self, other): @@ -798,7 +782,7 @@ def __eq__(self, other): elif isinstance(other, DenseVector): if self.size != len(other): return False - return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array) + return Vectors._equals(self.indices, self.values, list(range(len(other))), other.array) return False def __getitem__(self, index): @@ -880,7 +864,7 @@ def dense(*elements): >>> Vectors.dense(1.0, 2.0) DenseVector([1.0, 2.0]) """ - if len(elements) == 1 and not isinstance(elements[0], (float, int, long)): + if len(elements) == 1 and not isinstance(elements[0], (float, int)): # it's list, numpy.array or other iterable object. 
elements = elements[0] return DenseVector(elements) @@ -1279,7 +1263,7 @@ def toArray(self): Return an numpy.ndarray """ A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F') - for k in xrange(self.colPtrs.size - 1): + for k in range(self.colPtrs.size - 1): startptr = self.colPtrs[k] endptr = self.colPtrs[k + 1] if self.isTransposed: diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 56701758c89c9..603d31d3d7b26 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -21,9 +21,6 @@ import sys -if sys.version >= '3': - long = int - from py4j.java_gateway import JavaObject from pyspark import RDD, since @@ -95,9 +92,9 @@ def __init__(self, rows, numRows=0, numCols=0): """ if isinstance(rows, RDD): rows = rows.map(_convert_to_vector) - java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols)) + java_matrix = callMLlibFunc("createRowMatrix", rows, int(numRows), int(numCols)) elif isinstance(rows, DataFrame): - java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols)) + java_matrix = callMLlibFunc("createRowMatrix", rows, int(numRows), int(numCols)) elif (isinstance(rows, JavaObject) and rows.getClass().getSimpleName() == "RowMatrix"): java_matrix = rows @@ -439,13 +436,13 @@ class IndexedRow(object): """ Represents a row of an IndexedRowMatrix. - Just a wrapper over a (long, vector) tuple. + Just a wrapper over a (int, vector) tuple. :param index: The index for the given row. :param vector: The row in the matrix at the given index. """ def __init__(self, index, vector): - self.index = long(index) + self.index = int(index) self.vector = _convert_to_vector(vector) def __repr__(self): @@ -465,8 +462,8 @@ class IndexedRowMatrix(DistributedMatrix): """ Represents a row-oriented distributed Matrix with indexed rows. - :param rows: An RDD of IndexedRows or (long, vector) tuples or a DataFrame consisting of a - long typed column of indices and a vector typed column. + :param rows: An RDD of IndexedRows or (int, vector) tuples or a DataFrame consisting of a + int typed column of indices and a vector typed column. :param numRows: Number of rows in the matrix. A non-positive value means unknown, at which point the number of rows will be determined by the max row @@ -510,14 +507,14 @@ def __init__(self, rows, numRows=0, numCols=0): # both be easily serialized. We will convert back to # IndexedRows on the Scala side. java_matrix = callMLlibFunc("createIndexedRowMatrix", rows.toDF(), - long(numRows), int(numCols)) + int(numRows), int(numCols)) elif isinstance(rows, DataFrame): - java_matrix = callMLlibFunc("createIndexedRowMatrix", rows, long(numRows), int(numCols)) + java_matrix = callMLlibFunc("createIndexedRowMatrix", rows, int(numRows), int(numCols)) elif (isinstance(rows, JavaObject) and rows.getClass().getSimpleName() == "IndexedRowMatrix"): java_matrix = rows else: - raise TypeError("rows should be an RDD of IndexedRows or (long, vector) tuples, " + raise TypeError("rows should be an RDD of IndexedRows or (int, vector) tuples, " "got %s" % type(rows)) self._java_matrix_wrapper = JavaModelWrapper(java_matrix) @@ -731,15 +728,15 @@ class MatrixEntry(object): """ Represents an entry of a CoordinateMatrix. - Just a wrapper over a (long, long, float) tuple. + Just a wrapper over a (int, int, float) tuple. :param i: The row index of the matrix. :param j: The column index of the matrix. 
:param value: The (i, j)th entry of the matrix, as a float. """ def __init__(self, i, j, value): - self.i = long(i) - self.j = long(j) + self.i = int(i) + self.j = int(j) self.value = float(value) def __repr__(self): @@ -760,7 +757,7 @@ class CoordinateMatrix(DistributedMatrix): Represents a matrix in coordinate format. :param entries: An RDD of MatrixEntry inputs or - (long, long, float) tuples. + (int, int, float) tuples. :param numRows: Number of rows in the matrix. A non-positive value means unknown, at which point the number of rows will be determined by the max row @@ -804,13 +801,13 @@ def __init__(self, entries, numRows=0, numCols=0): # each be easily serialized. We will convert back to # MatrixEntry inputs on the Scala side. java_matrix = callMLlibFunc("createCoordinateMatrix", entries.toDF(), - long(numRows), long(numCols)) + int(numRows), int(numCols)) elif (isinstance(entries, JavaObject) and entries.getClass().getSimpleName() == "CoordinateMatrix"): java_matrix = entries else: raise TypeError("entries should be an RDD of MatrixEntry entries or " - "(long, long, float) tuples, got %s" % type(entries)) + "(int, int, float) tuples, got %s" % type(entries)) self._java_matrix_wrapper = JavaModelWrapper(java_matrix) @@ -1044,7 +1041,7 @@ def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0): # the Scala side. java_matrix = callMLlibFunc("createBlockMatrix", blocks.toDF(), int(rowsPerBlock), int(colsPerBlock), - long(numRows), long(numCols)) + int(numRows), int(numCols)) elif (isinstance(blocks, JavaObject) and blocks.getClass().getSimpleName() == "BlockMatrix"): java_matrix = blocks diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py index 7250eab6705a7..56444c152f0ba 100644 --- a/python/pyspark/mllib/stat/KernelDensity.py +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -15,11 +15,6 @@ # limitations under the License. # -import sys - -if sys.version > '3': - xrange = range - import numpy as np from pyspark.mllib.common import callMLlibFunc diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index d49f741a2f44a..43454ba5187dd 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -16,10 +16,8 @@ # import sys -if sys.version >= '3': - basestring = str -from pyspark.rdd import RDD, ignore_unicode_prefix +from pyspark.rdd import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -157,7 +155,6 @@ def corr(x, y=None, method=None): return callMLlibFunc("corr", x.map(float), y.map(float), method) @staticmethod - @ignore_unicode_prefix def chiSqTest(observed, expected=None): """ If `observed` is Vector, conduct Pearson's chi-squared goodness @@ -199,9 +196,9 @@ def chiSqTest(observed, expected=None): >>> print(round(pearson.pValue, 4)) 0.8187 >>> pearson.method - u'pearson' + 'pearson' >>> pearson.nullHypothesis - u'observed follows the same distribution as expected.' + 'observed follows the same distribution as expected.' 
>>> observed = Vectors.dense([21, 38, 43, 80]) >>> expected = Vectors.dense([3, 5, 7, 20]) @@ -242,7 +239,6 @@ def chiSqTest(observed, expected=None): return ChiSqTestResult(jmodel) @staticmethod - @ignore_unicode_prefix def kolmogorovSmirnovTest(data, distName="norm", *params): """ Performs the Kolmogorov-Smirnov (KS) test for data sampled from @@ -282,7 +278,7 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): >>> print(round(ksmodel.statistic, 3)) 0.175 >>> ksmodel.nullHypothesis - u'Sample follows theoretical distribution' + 'Sample follows theoretical distribution' >>> data = sc.parallelize([2.0, 3.0, 4.0]) >>> ksmodel = kstest(data, "norm", 3.0, 1.0) @@ -293,7 +289,7 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): """ if not isinstance(data, RDD): raise TypeError("data should be an RDD, got %s." % type(data)) - if not isinstance(distName, basestring): + if not isinstance(distName, str): raise TypeError("distName should be a string, got %s." % type(distName)) params = [float(param) for param in params] diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py index 312730e8aff8b..21c2bb422a3c3 100644 --- a/python/pyspark/mllib/tests/test_linalg.py +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -31,9 +31,6 @@ from pyspark.testing.mllibutils import MLlibTestCase from pyspark.testing.utils import have_scipy -if sys.version >= '3': - long = int - class VectorTests(MLlibTestCase): @@ -447,7 +444,7 @@ def test_row_matrix_from_dataframe(self): def test_indexed_row_matrix_from_dataframe(self): from pyspark.sql.utils import IllegalArgumentException - df = self.spark.createDataFrame([Row(long(0), Vectors.dense(1))]) + df = self.spark.createDataFrame([Row(int(0), Vectors.dense(1))]) matrix = IndexedRowMatrix(df) self.assertEqual(matrix.numRows(), 1) self.assertEqual(matrix.numCols(), 1) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 2d8df461acf9f..e05dfdb953ceb 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -15,8 +15,6 @@ # limitations under the License. 
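The removals running through mllib above all lean on the same Python 3 facts, which is why the per-module compatibility shims (basestring, unicode, long, xrange) can simply be deleted rather than rewritten. A minimal illustrative sketch of the equivalences, independent of Spark and not part of the patch itself:

    import array

    assert isinstance(u"spark", str)                # str subsumes basestring and unicode
    assert isinstance(2 ** 100, int)                # int is arbitrary precision, so long() is redundant
    assert isinstance(range(3), range)              # range is the lazy sequence that xrange used to be
    assert array.array("d", range(3))[-1] == 2.0    # range can still be materialized by array constructors

This is why checks such as _convert_to_vector's `type(l) in (..., range)` and MatrixEntry's `int(i)` coercion behave exactly as the old xrange/long versions did.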
# -from __future__ import absolute_import - import sys import random diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index f0f9cda4672b1..a0be29a82e3dc 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -18,10 +18,6 @@ import sys import numpy as np -if sys.version > '3': - xrange = range - basestring = str - from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -46,7 +42,7 @@ def _parse_libsvm_line(line): nnz = len(items) - 1 indices = np.zeros(nnz, dtype=np.int32) values = np.zeros(nnz) - for i in xrange(nnz): + for i in range(nnz): index, value = items[1 + i].split(":") indices[i] = int(index) - 1 values[i] = float(value) @@ -61,10 +57,10 @@ def _convert_labeled_point_to_libsvm(p): v = _convert_to_vector(p.features) if isinstance(v, SparseVector): nnz = len(v.indices) - for i in xrange(nnz): + for i in range(nnz): items.append(str(v.indices[i] + 1) + ":" + str(v.values[i])) else: - for i in xrange(len(v)): + for i in range(len(v)): items.append(str(i + 1) + ":" + str(v[i])) return " ".join(items) @@ -396,8 +392,8 @@ def save(self, sc, path): """Save this model to the given path.""" if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) + if not isinstance(path, str): + raise TypeError("path should be a string, got type %s" % type(path)) self._java_model.save(sc._jsc.sc(), path) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index db0c1971cd2fe..437b2c446529a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -33,15 +33,10 @@ from functools import reduce from math import sqrt, log, isinf, isnan, pow, ceil -if sys.version > '3': - basestring = unicode = str -else: - from itertools import imap as map, ifilter as filter - from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \ CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \ - UTF8Deserializer, pack_long, read_int, write_int + pack_long, read_int, write_int from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -93,7 +88,7 @@ def portable_hash(x): 219750521 """ - if sys.version_info >= (3, 2, 3) and 'PYTHONHASHSEED' not in os.environ: + if 'PYTHONHASHSEED' not in os.environ: raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED") if x is None: @@ -204,19 +199,6 @@ def __del__(self): return iter(PyLocalIterable(sock_info, serializer)) -def ignore_unicode_prefix(f): - """ - Ignore the 'u' prefix of string in doc tests, to make it works - in both python 2 and 3 - """ - if sys.version >= '3': - # the representation of unicode string in Python 3 does not have prefix 'u', - # so remove the prefix 'u' for doc tests - literal_re = re.compile(r"(\W|^)[uU](['])", re.UNICODE) - f.__doc__ = literal_re.sub(r'\1\2', f.__doc__) - return f - - class Partitioner(object): def __init__(self, numPartitions, partitionFunc): self.numPartitions = numPartitions @@ -797,13 +779,12 @@ def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash): """ return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, 
partitionFunc) - @ignore_unicode_prefix def pipe(self, command, env=None, checkCode=False): """ Return an RDD created by piping elements to a forked external process. >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect() - [u'1', u'2', u'', u'3'] + ['1', '2', '', '3'] :param checkCode: whether or not to check the return value of the shell command. """ @@ -816,7 +797,7 @@ def func(iterator): def pipe_objs(out): for obj in iterator: - s = unicode(obj).rstrip('\n') + '\n' + s = str(obj).rstrip('\n') + '\n' out.write(s.encode('utf-8')) out.close() Thread(target=pipe_objs, args=[pipe.stdin]).start() @@ -1591,7 +1572,6 @@ def saveAsPickleFile(self, path, batchSize=10): ser = BatchedSerializer(PickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) - @ignore_unicode_prefix def saveAsTextFile(self, path, compressionCodecClass=None): """ Save this RDD as a text file, using string representations of elements. @@ -1625,13 +1605,13 @@ def saveAsTextFile(self, path, compressionCodecClass=None): >>> from fileinput import input, hook_compressed >>> result = sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed)) >>> b''.join(result).decode('utf-8') - u'bar\\nfoo\\n' + 'bar\\nfoo\\n' """ def func(split, iterator): for x in iterator: - if not isinstance(x, (unicode, bytes)): - x = unicode(x) - if isinstance(x, unicode): + if not isinstance(x, (str, bytes)): + x = str(x) + if isinstance(x, str): x = x.encode("utf-8") yield x keyed = self.mapPartitionsWithIndex(func) @@ -2281,14 +2261,13 @@ def name(self): if n: return n - @ignore_unicode_prefix def setName(self, name): """ Assign a name to this RDD. >>> rdd1 = sc.parallelize([1, 2]) >>> rdd1.setName('RDD1').name() - u'RDD1' + 'RDD1' """ self._jrdd.setName(name) return self diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index c867b51877ffe..cd2a59513bb17 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -15,10 +15,7 @@ # limitations under the License. # -try: - from collections.abc import Iterable -except ImportError: - from collections import Iterable +from collections.abc import Iterable __all__ = ["ResultIterable"] diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 49b7cb4546676..80ce9b8408d4e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -58,18 +58,11 @@ import collections import zlib import itertools - -if sys.version < '3': - import cPickle as pickle - from itertools import izip as zip, imap as map -else: - import pickle - basestring = unicode = str - xrange = range +import pickle pickle_protocol = pickle.HIGHEST_PROTOCOL from pyspark import cloudpickle -from pyspark.util import _exception_message, print_exec +from pyspark.util import print_exec __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] @@ -132,11 +125,6 @@ class FramedSerializer(Serializer): where `length` is a 32-bit integer and data is `length` bytes. """ - def __init__(self): - # On Python 2.6, we can't write bytearrays to streams, so we need to convert them - # to strings first. Check if the version number is that old. 
- self._only_write_strings = sys.version_info[0:2] <= (2, 6) - def dump_stream(self, iterator, stream): for obj in iterator: self._write_with_length(obj, stream) @@ -155,10 +143,7 @@ def _write_with_length(self, obj, stream): if len(serialized) > (1 << 31): raise ValueError("can not serialize object larger than 2G") write_int(len(serialized), stream) - if self._only_write_strings: - stream.write(str(serialized)) - else: - stream.write(serialized) + stream.write(serialized) def _read_with_length(self, stream): length = read_int(stream) @@ -204,7 +189,7 @@ def _batched(self, iterator): yield list(iterator) elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"): n = len(iterator) - for i in xrange(0, n, self.batchSize): + for i in range(0, n, self.batchSize): yield iterator[i: i + self.batchSize] else: items = [] @@ -395,23 +380,8 @@ def _copy_func(f): return types.FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__) - def _kwdefaults(f): - # __kwdefaults__ contains the default values of keyword-only arguments which are - # introduced from Python 3. The possible cases for __kwdefaults__ in namedtuple - # are as below: - # - # - Does not exist in Python 2. - # - Returns None in <= Python 3.5.x. - # - Returns a dictionary containing the default values to the keys from Python 3.6.x - # (See https://bugs.python.org/issue25628). - kargs = getattr(f, "__kwdefaults__", None) - if kargs is None: - return {} - else: - return kargs - _old_namedtuple = _copy_func(collections.namedtuple) - _old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple) + _old_namedtuple_kwdefaults = collections.namedtuple.__kwdefaults__ def namedtuple(*args, **kwargs): for k, v in _old_namedtuple_kwdefaults.items(): @@ -453,12 +423,8 @@ class PickleSerializer(FramedSerializer): def dumps(self, obj): return pickle.dumps(obj, pickle_protocol) - if sys.version >= '3': - def loads(self, obj, encoding="bytes"): - return pickle.loads(obj, encoding=encoding) - else: - def loads(self, obj, encoding=None): - return pickle.loads(obj) + def loads(self, obj, encoding="bytes"): + return pickle.loads(obj, encoding=encoding) class CloudPickleSerializer(PickleSerializer): @@ -469,7 +435,7 @@ def dumps(self, obj): except pickle.PickleError: raise except Exception as e: - emsg = _exception_message(e) + emsg = str(e) if "'i' format requires" in emsg: msg = "Object too large to serialize: %s" % emsg else: diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 65e3bdbc05ce8..cde163bd2d73d 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -26,11 +26,8 @@ import platform import warnings -import py4j - -from pyspark import SparkConf from pyspark.context import SparkContext -from pyspark.sql import SparkSession, SQLContext +from pyspark.sql import SparkSession if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index c28cb8c3b9cbe..af32469e82b43 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -39,9 +39,6 @@ - :class:`pyspark.sql.Window` For working with window functions. 
""" -from __future__ import absolute_import - - from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration from pyspark.sql.session import SparkSession diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py index ed62a72d6c8fb..974412ee4efea 100644 --- a/python/pyspark/sql/avro/functions.py +++ b/python/pyspark/sql/avro/functions.py @@ -21,12 +21,10 @@ from pyspark import since, SparkContext -from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column from pyspark.util import _print_missing_jar -@ignore_unicode_prefix @since(3.0) def from_avro(data, jsonFormatSchema, options={}): """ @@ -45,7 +43,7 @@ def from_avro(data, jsonFormatSchema, options={}): >>> from pyspark.sql import Row >>> from pyspark.sql.avro.functions import from_avro, to_avro - >>> data = [(1, Row(name='Alice', age=2))] + >>> data = [(1, Row(age=2, name='Alice'))] >>> df = spark.createDataFrame(data, ("key", "value")) >>> avroDf = df.select(to_avro(df.value).alias("avro")) >>> avroDf.collect() @@ -55,7 +53,7 @@ def from_avro(data, jsonFormatSchema, options={}): ... "fields":[{"name":"age","type":["long","null"]}, ... {"name":"name","type":["string","null"]}]},"null"]}]}''' >>> avroDf.select(from_avro(avroDf.avro, jsonFormatSchema).alias("value")).collect() - [Row(value=Row(avro=Row(age=2, name=u'Alice')))] + [Row(value=Row(avro=Row(age=2, name='Alice')))] """ sc = SparkContext._active_spark_context @@ -69,7 +67,6 @@ def from_avro(data, jsonFormatSchema, options={}): return Column(jc) -@ignore_unicode_prefix @since(3.0) def to_avro(data, jsonFormatSchema=""): """ diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 974251f63b37a..25fc696dac051 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -20,10 +20,8 @@ from collections import namedtuple from pyspark import since -from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.dataframe import DataFrame -from pyspark.sql.udf import UserDefinedFunction -from pyspark.sql.types import IntegerType, StringType, StructType +from pyspark.sql.types import StructType Database = namedtuple("Database", "name description locationUri") @@ -44,19 +42,16 @@ def __init__(self, sparkSession): self._jsparkSession = sparkSession._jsparkSession self._jcatalog = sparkSession._jsparkSession.catalog() - @ignore_unicode_prefix @since(2.0) def currentDatabase(self): """Returns the current default database in this session.""" return self._jcatalog.currentDatabase() - @ignore_unicode_prefix @since(2.0) def setCurrentDatabase(self, dbName): """Sets the current default database in this session.""" return self._jcatalog.setCurrentDatabase(dbName) - @ignore_unicode_prefix @since(2.0) def listDatabases(self): """Returns a list of databases available across all sessions.""" @@ -70,7 +65,6 @@ def listDatabases(self): locationUri=jdb.locationUri())) return databases - @ignore_unicode_prefix @since(2.0) def listTables(self, dbName=None): """Returns a list of tables/views in the specified database. @@ -92,7 +86,6 @@ def listTables(self, dbName=None): isTemporary=jtable.isTemporary())) return tables - @ignore_unicode_prefix @since(2.0) def listFunctions(self, dbName=None): """Returns a list of functions registered in the specified database. 
@@ -113,7 +106,6 @@ def listFunctions(self, dbName=None): isTemporary=jfunction.isTemporary())) return functions - @ignore_unicode_prefix @since(2.0) def listColumns(self, tableName, dbName=None): """Returns a list of columns for the given table/view in the specified database. diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index ef4944c9121a4..bd4c35576214e 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -19,15 +19,8 @@ import json import warnings -if sys.version >= '3': - basestring = str - long = int - -from py4j.java_gateway import is_instance_of - from pyspark import copy_func, since from pyspark.context import SparkContext -from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.types import * __all__ = ["Column"] @@ -46,7 +39,7 @@ def _create_column_from_name(name): def _to_java_column(col): if isinstance(col, Column): jcol = col._jc - elif isinstance(col, basestring): + elif isinstance(col, str): jcol = _create_column_from_name(col) else: raise TypeError( @@ -359,7 +352,7 @@ def __iter__(self): :param other: string in line >>> df.filter(df.name.contains('o')).collect() - [Row(age=5, name=u'Bob')] + [Row(age=5, name='Bob')] """ _rlike_doc = """ SQL RLIKE expression (LIKE with Regex). Returns a boolean :class:`Column` based on a regex @@ -368,7 +361,7 @@ def __iter__(self): :param other: an extended regex expression >>> df.filter(df.name.rlike('ice$')).collect() - [Row(age=2, name=u'Alice')] + [Row(age=2, name='Alice')] """ _like_doc = """ SQL like expression. Returns a boolean :class:`Column` based on a SQL LIKE match. @@ -378,7 +371,7 @@ def __iter__(self): See :func:`rlike` for a regex version >>> df.filter(df.name.like('Al%')).collect() - [Row(age=2, name=u'Alice')] + [Row(age=2, name='Alice')] """ _startswith_doc = """ String starts with. Returns a boolean :class:`Column` based on a string match. 
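Narrowing isinstance checks from basestring to str, as in _to_java_column above, is safe because on Python 3 a u-prefixed literal is the same str type; there is no separate unicode class left to accept. A hedged sketch of the invariant the change relies on:

    name = u"age"
    assert type(name) is str
    assert isinstance(name, str)
    # so _to_java_column("age") and _to_java_column(u"age") take the same code path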
@@ -386,7 +379,7 @@ def __iter__(self): :param other: string at start of line (do not use a regex `^`) >>> df.filter(df.name.startswith('Al')).collect() - [Row(age=2, name=u'Alice')] + [Row(age=2, name='Alice')] >>> df.filter(df.name.startswith('^Al')).collect() [] """ @@ -396,18 +389,17 @@ def __iter__(self): :param other: string at end of line (do not use a regex `$`) >>> df.filter(df.name.endswith('ice')).collect() - [Row(age=2, name=u'Alice')] + [Row(age=2, name='Alice')] >>> df.filter(df.name.endswith('ice$')).collect() [] """ - contains = ignore_unicode_prefix(_bin_op("contains", _contains_doc)) - rlike = ignore_unicode_prefix(_bin_op("rlike", _rlike_doc)) - like = ignore_unicode_prefix(_bin_op("like", _like_doc)) - startswith = ignore_unicode_prefix(_bin_op("startsWith", _startswith_doc)) - endswith = ignore_unicode_prefix(_bin_op("endsWith", _endswith_doc)) + contains = _bin_op("contains", _contains_doc) + rlike = _bin_op("rlike", _rlike_doc) + like = _bin_op("like", _like_doc) + startswith = _bin_op("startsWith", _startswith_doc) + endswith = _bin_op("endsWith", _endswith_doc) - @ignore_unicode_prefix @since(1.3) def substr(self, startPos, length): """ @@ -417,7 +409,7 @@ def substr(self, startPos, length): :param length: length of the substring (int or Column) >>> df.select(df.name.substr(1, 3).alias("col")).collect() - [Row(col=u'Ali'), Row(col=u'Bob')] + [Row(col='Ali'), Row(col='Bob')] """ if type(startPos) != type(length): raise TypeError( @@ -435,7 +427,6 @@ def substr(self, startPos, length): raise TypeError("Unexpected type: %s" % type(startPos)) return Column(jc) - @ignore_unicode_prefix @since(1.5) def isin(self, *cols): """ @@ -443,9 +434,9 @@ def isin(self, *cols): expression is contained by the evaluated values of the arguments. >>> df[df.name.isin("Bob", "Mike")].collect() - [Row(age=5, name=u'Bob')] + [Row(age=5, name='Bob')] >>> df[df.age.isin([1, 2, 3])].collect() - [Row(age=2, name=u'Alice')] + [Row(age=2, name='Alice')] """ if len(cols) == 1 and isinstance(cols[0], (list, set)): cols = cols[0] @@ -461,7 +452,7 @@ def isin(self, *cols): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.asc()).collect() - [Row(name=u'Alice'), Row(name=u'Tom')] + [Row(name='Alice'), Row(name='Tom')] """ _asc_nulls_first_doc = """ Returns a sort expression based on ascending order of the column, and null values @@ -470,7 +461,7 @@ def isin(self, *cols): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect() - [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + [Row(name=None), Row(name='Alice'), Row(name='Tom')] .. versionadded:: 2.4 """ @@ -481,7 +472,7 @@ def isin(self, *cols): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect() - [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + [Row(name='Alice'), Row(name='Tom'), Row(name=None)] .. 
versionadded:: 2.4 """ @@ -491,7 +482,7 @@ def isin(self, *cols): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.desc()).collect() - [Row(name=u'Tom'), Row(name=u'Alice')] + [Row(name='Tom'), Row(name='Alice')] """ _desc_nulls_first_doc = """ Returns a sort expression based on the descending order of the column, and null values @@ -500,7 +491,7 @@ def isin(self, *cols): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect() - [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + [Row(name=None), Row(name='Tom'), Row(name='Alice')] .. versionadded:: 2.4 """ @@ -511,37 +502,37 @@ def isin(self, *cols): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect() - [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + [Row(name='Tom'), Row(name='Alice'), Row(name=None)] .. versionadded:: 2.4 """ - asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc)) - asc_nulls_first = ignore_unicode_prefix(_unary_op("asc_nulls_first", _asc_nulls_first_doc)) - asc_nulls_last = ignore_unicode_prefix(_unary_op("asc_nulls_last", _asc_nulls_last_doc)) - desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc)) - desc_nulls_first = ignore_unicode_prefix(_unary_op("desc_nulls_first", _desc_nulls_first_doc)) - desc_nulls_last = ignore_unicode_prefix(_unary_op("desc_nulls_last", _desc_nulls_last_doc)) + asc = _unary_op("asc", _asc_doc) + asc_nulls_first = _unary_op("asc_nulls_first", _asc_nulls_first_doc) + asc_nulls_last = _unary_op("asc_nulls_last", _asc_nulls_last_doc) + desc = _unary_op("desc", _desc_doc) + desc_nulls_first = _unary_op("desc_nulls_first", _desc_nulls_first_doc) + desc_nulls_last = _unary_op("desc_nulls_last", _desc_nulls_last_doc) _isNull_doc = """ True if the current expression is null. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Alice', height=None)]) >>> df.filter(df.height.isNull()).collect() - [Row(height=None, name=u'Alice')] + [Row(name='Alice', height=None)] """ _isNotNull_doc = """ True if the current expression is NOT null. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Alice', height=None)]) >>> df.filter(df.height.isNotNull()).collect() - [Row(height=80, name=u'Tom')] + [Row(name='Tom', height=80)] """ - isNull = ignore_unicode_prefix(_unary_op("isNull", _isNull_doc)) - isNotNull = ignore_unicode_prefix(_unary_op("isNotNull", _isNotNull_doc)) + isNull = _unary_op("isNull", _isNull_doc) + isNotNull = _unary_op("isNotNull", _isNotNull_doc) @since(1.3) def alias(self, *alias, **kwargs): @@ -581,17 +572,16 @@ def alias(self, *alias, **kwargs): name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.") - @ignore_unicode_prefix @since(1.3) def cast(self, dataType): """ Convert the column into type ``dataType``. 
>>> df.select(df.age.cast("string").alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] + [Row(ages='2'), Row(ages='5')] >>> df.select(df.age.cast(StringType()).alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] + [Row(ages='2'), Row(ages='5')] """ - if isinstance(dataType, basestring): + if isinstance(dataType, str): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): from pyspark.sql import SparkSession diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index 71ea1631718f1..eab084a1faddf 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -18,10 +18,6 @@ import sys from pyspark import since, _NoValue -from pyspark.rdd import ignore_unicode_prefix - -if sys.version_info[0] >= 3: - basestring = str class RuntimeConfig(object): @@ -34,13 +30,11 @@ def __init__(self, jconf): """Create a new RuntimeConfig that wraps the underlying JVM object.""" self._jconf = jconf - @ignore_unicode_prefix @since(2.0) def set(self, key, value): """Sets the given Spark runtime configuration property.""" self._jconf.set(key, value) - @ignore_unicode_prefix @since(2.0) def get(self, key, default=_NoValue): """Returns the value of Spark runtime configuration property for the given key, @@ -54,7 +48,6 @@ def get(self, key, default=_NoValue): self._checkType(default, "default") return self._jconf.get(key, default) - @ignore_unicode_prefix @since(2.0) def unset(self, key): """Resets the configuration property for the given key.""" @@ -62,11 +55,10 @@ def unset(self, key): def _checkType(self, obj, identifier): """Assert that an object is of type str.""" - if not isinstance(obj, basestring): + if not isinstance(obj, str): raise TypeError("expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__)) - @ignore_unicode_prefix @since(2.4) def isModifiable(self, key): """Indicates whether the configuration property with the given key diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 956343a2310b8..7fbcf85cb1d50 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -15,15 +15,10 @@ # limitations under the License. # -from __future__ import print_function import sys import warnings -if sys.version >= '3': - basestring = unicode = str - from pyspark import since, _NoValue -from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -52,7 +47,6 @@ class SQLContext(object): _instantiatedContext = None - @ignore_unicode_prefix def __init__(self, sparkContext, sparkSession=None, jsqlContext=None): """Creates a new SQLContext. @@ -70,7 +64,7 @@ def __init__(self, sparkContext, sparkSession=None, jsqlContext=None): [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \ dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + [(1, 'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ warnings.warn( "Deprecated in 3.0.0. 
Use SparkSession.builder.getOrCreate() instead.", @@ -142,7 +136,6 @@ def setConf(self, key, value): """ self.sparkSession.conf.set(key, value) - @ignore_unicode_prefix @since(1.3) def getConf(self, key, defaultValue=_NoValue): """Returns the value of Spark SQL configuration property for the given key. @@ -152,12 +145,12 @@ def getConf(self, key, defaultValue=_NoValue): the system default value. >>> sqlContext.getConf("spark.sql.shuffle.partitions") - u'200' - >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10") - u'10' - >>> sqlContext.setConf("spark.sql.shuffle.partitions", u"50") - >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10") - u'50' + '200' + >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10") + '10' + >>> sqlContext.setConf("spark.sql.shuffle.partitions", "50") + >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10") + '50' """ return self.sparkSession.conf.get(key, defaultValue) @@ -229,7 +222,6 @@ def _inferSchema(self, rdd, samplingRatio=None): return self.sparkSession._inferSchema(rdd, samplingRatio) @since(1.3) - @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): """ Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. @@ -274,27 +266,27 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr >>> l = [('Alice', 1)] >>> sqlContext.createDataFrame(l).collect() - [Row(_1=u'Alice', _2=1)] + [Row(_1='Alice', _2=1)] >>> sqlContext.createDataFrame(l, ['name', 'age']).collect() - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> d = [{'name': 'Alice', 'age': 1}] >>> sqlContext.createDataFrame(d).collect() - [Row(age=1, name=u'Alice')] + [Row(age=1, name='Alice')] >>> rdd = sc.parallelize(l) >>> sqlContext.createDataFrame(rdd).collect() - [Row(_1=u'Alice', _2=1)] + [Row(_1='Alice', _2=1)] >>> df = sqlContext.createDataFrame(rdd, ['name', 'age']) >>> df.collect() - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> from pyspark.sql import Row >>> Person = Row('name', 'age') >>> person = rdd.map(lambda r: Person(*r)) >>> df2 = sqlContext.createDataFrame(person) >>> df2.collect() - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> from pyspark.sql.types import * >>> schema = StructType([ @@ -302,15 +294,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr ... StructField("age", IntegerType(), True)]) >>> df3 = sqlContext.createDataFrame(rdd, schema) >>> df3.collect() - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP [Row(0=1, 1=2)] >>> sqlContext.createDataFrame(rdd, "a: string, b: int").collect() - [Row(a=u'Alice', b=1)] + [Row(a='Alice', b=1)] >>> rdd = rdd.map(lambda row: row[1]) >>> sqlContext.createDataFrame(rdd, "int").collect() [Row(value=1)] @@ -358,7 +350,6 @@ def createExternalTable(self, tableName, path=None, source=None, schema=None, ** return self.sparkSession.catalog.createExternalTable( tableName, path, source, schema, **options) - @ignore_unicode_prefix @since(1.0) def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. 
@@ -368,7 +359,7 @@ def sql(self, sqlQuery): >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + [Row(f1=1, f2='row1'), Row(f1=2, f2='row2'), Row(f1=3, f2='row3')] """ return self.sparkSession.sql(sqlQuery) @@ -385,7 +376,6 @@ def table(self, tableName): """ return self.sparkSession.table(tableName) - @ignore_unicode_prefix @since(1.3) def tables(self, dbName=None): """Returns a :class:`DataFrame` containing names of tables in the given database. @@ -401,7 +391,7 @@ def tables(self, dbName=None): >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() - Row(database=u'', tableName=u'table1', isTemporary=True) + Row(database='', tableName='table1', isTemporary=True) """ if dbName is None: return DataFrame(self._ssql_ctx.tables(), self) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3ad899bcc3670..023fbeabcbabc 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -17,21 +17,12 @@ import sys import random - -if sys.version >= '3': - basestring = unicode = str - long = int - from functools import reduce - from html import escape as html_escape -else: - from itertools import imap as map - from cgi import escape as html_escape - import warnings +from functools import reduce +from html import escape as html_escape from pyspark import copy_func, since, _NoValue -from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, \ - ignore_unicode_prefix +from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket from pyspark.serializers import BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel @@ -109,7 +100,6 @@ def stat(self): """ return DataFrameStatFunctions(self) - @ignore_unicode_prefix @since(1.3) def toJSON(self, use_unicode=True): """Converts a :class:`DataFrame` into a :class:`RDD` of string. @@ -117,7 +107,7 @@ def toJSON(self, use_unicode=True): Each row is turned into a JSON document as one element in the returned RDD. >>> df.toJSON().first() - u'{"age":2,"name":"Alice"}' + '{"age":2,"name":"Alice"}' """ rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) @@ -330,11 +320,11 @@ def explain(self, extended=None, mode=None): # For the case when extended is mode: # df.explain("formatted") - is_extended_as_mode = isinstance(extended, basestring) and mode is None + is_extended_as_mode = isinstance(extended, str) and mode is None # For the mode specified: # df.explain(mode="formatted") - is_mode_case = extended is None and isinstance(mode, basestring) + is_mode_case = extended is None and isinstance(mode, str) if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case): argtypes = [ @@ -568,7 +558,7 @@ def hint(self, name, *parameters): if not isinstance(name, str): raise TypeError("name should be provided as str, got {0}".format(type(name))) - allowed_types = (basestring, list, float, int) + allowed_types = (str, list, float, int) for p in parameters: if not isinstance(p, allowed_types): raise TypeError( @@ -587,19 +577,17 @@ def count(self): """ return int(self._jdf.count()) - @ignore_unicode_prefix @since(1.3) def collect(self): """Returns all the records as a list of :class:`Row`. 
>>> df.collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + [Row(age=2, name='Alice'), Row(age=5, name='Bob')] """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectToPython() return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) - @ignore_unicode_prefix @since(2.0) def toLocalIterator(self, prefetchPartitions=False): """ @@ -612,36 +600,33 @@ def toLocalIterator(self, prefetchPartitions=False): before it is needed. >>> list(df.toLocalIterator()) - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + [Row(age=2, name='Alice'), Row(age=5, name='Bob')] """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.toPythonIterator(prefetchPartitions) return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer())) - @ignore_unicode_prefix @since(1.3) def limit(self, num): """Limits the result count to the number specified. >>> df.limit(1).collect() - [Row(age=2, name=u'Alice')] + [Row(age=2, name='Alice')] >>> df.limit(0).collect() [] """ jdf = self._jdf.limit(num) return DataFrame(jdf, self.sql_ctx) - @ignore_unicode_prefix @since(1.3) def take(self, num): """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. >>> df.take(2) - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + [Row(age=2, name='Alice'), Row(age=5, name='Bob')] """ return self.limit(num).collect() - @ignore_unicode_prefix @since(3.0) def tail(self, num): """ @@ -651,7 +636,7 @@ def tail(self, num): a very large ``num`` can crash the driver process with OutOfMemoryError. >>> df.tail(1) - [Row(age=5, name=u'Bob')] + [Row(age=5, name='Bob')] """ with SCCallSiteSync(self._sc): sock_info = self._jdf.tailToPython(num) @@ -818,7 +803,7 @@ def repartition(self, numPartitions, *cols): else: return DataFrame( self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx) - elif isinstance(numPartitions, (basestring, Column)): + elif isinstance(numPartitions, (str, Column)): cols = (numPartitions, ) + cols return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx) else: @@ -869,7 +854,7 @@ def repartitionByRange(self, numPartitions, *cols): else: return DataFrame( self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx) - elif isinstance(numPartitions, (basestring, Column)): + elif isinstance(numPartitions, (str, Column)): cols = (numPartitions,) + cols return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx) else: @@ -944,7 +929,7 @@ def sample(self, withReplacement=None, fraction=None, seed=None): fraction = withReplacement withReplacement = None - seed = long(seed) if seed is not None else None + seed = int(seed) if seed is not None else None args = [arg for arg in [withReplacement, fraction, seed] if arg is not None] jdf = self._jdf.sample(*args) return DataFrame(jdf, self.sql_ctx) @@ -978,15 +963,15 @@ def sampleBy(self, col, fractions, seed=None): .. 
versionchanged:: 3.0 Added sampling by a column of :class:`Column` """ - if isinstance(col, basestring): + if isinstance(col, str): col = Column(col) elif not isinstance(col, Column): raise ValueError("col must be a string or a column, but got %r" % type(col)) if not isinstance(fractions, dict): raise ValueError("fractions must be a dict but got %r" % type(fractions)) for k, v in fractions.items(): - if not isinstance(k, (float, int, long, basestring)): - raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) + if not isinstance(k, (float, int, str)): + raise ValueError("key must be float, int, or string, but got %r" % type(k)) fractions[k] = float(v) col = col._jc seed = seed if seed is not None else random.randint(0, sys.maxsize) @@ -1011,7 +996,7 @@ def randomSplit(self, weights, seed=None): if w < 0.0: raise ValueError("Weights must be positive. Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) - rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed)) + rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), int(seed)) return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] @property @@ -1052,12 +1037,11 @@ def colRegex(self, colName): | 3| +----+ """ - if not isinstance(colName, basestring): + if not isinstance(colName, str): raise ValueError("colName should be provided as string") jc = self._jdf.colRegex(colName) return Column(jc) - @ignore_unicode_prefix @since(1.3) def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. @@ -1070,12 +1054,11 @@ def alias(self, alias): >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age") \ .sort(desc("df_as1.name")).collect() - [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)] + [Row(name='Bob', name='Bob', age=5), Row(name='Alice', name='Alice', age=2)] """ - assert isinstance(alias, basestring), "alias should be a string" + assert isinstance(alias, str), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) - @ignore_unicode_prefix @since(2.1) def crossJoin(self, other): """Returns the cartesian product with another :class:`DataFrame`. @@ -1083,18 +1066,17 @@ def crossJoin(self, other): :param other: Right side of the cartesian product. >>> df.select("age", "name").collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + [Row(age=2, name='Alice'), Row(age=5, name='Bob')] >>> df2.select("name", "height").collect() - [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85)] + [Row(name='Tom', height=80), Row(name='Bob', height=85)] >>> df.crossJoin(df2.select("height")).select("age", "name", "height").collect() - [Row(age=2, name=u'Alice', height=80), Row(age=2, name=u'Alice', height=85), - Row(age=5, name=u'Bob', height=80), Row(age=5, name=u'Bob', height=85)] + [Row(age=2, name='Alice', height=80), Row(age=2, name='Alice', height=85), + Row(age=5, name='Bob', height=80), Row(age=5, name='Bob', height=85)] """ jdf = self._jdf.crossJoin(other._jdf) return DataFrame(jdf, self.sql_ctx) - @ignore_unicode_prefix @since(1.3) def join(self, other, on=None, how=None): """Joins with another :class:`DataFrame`, using the given join expression. 
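With long gone, sampleBy's fractions dictionary is validated against plain (float, int, str) keys, and seeds are coerced with int() in sample, sampleBy and randomSplit alike. A hedged usage sketch, assuming an active SparkSession bound to the name spark as other doctests in this patch do:

    from pyspark.sql.functions import col

    dataset = spark.range(0, 100).select((col("id") % 3).alias("key"))
    sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
    # the keys 0 and 1 and the seed are ordinary Python ints; no long() wrapper is needed

Per the versionchanged:: 3.0 note above, a Column such as col("key") may also be passed in place of the string column name.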
@@ -1113,27 +1095,27 @@ def join(self, other, on=None, how=None): >>> from pyspark.sql.functions import desc >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height) \ .sort(desc("name")).collect() - [Row(name=u'Bob', height=85), Row(name=u'Alice', height=None), Row(name=None, height=80)] + [Row(name='Bob', height=85), Row(name='Alice', height=None), Row(name=None, height=80)] >>> df.join(df2, 'name', 'outer').select('name', 'height').sort(desc("name")).collect() - [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + [Row(name='Tom', height=80), Row(name='Bob', height=85), Row(name='Alice', height=None)] >>> cond = [df.name == df3.name, df.age == df3.age] >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() - [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + [Row(name='Alice', age=2), Row(name='Bob', age=5)] >>> df.join(df2, 'name').select(df.name, df2.height).collect() - [Row(name=u'Bob', height=85)] + [Row(name='Bob', height=85)] >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect() - [Row(name=u'Bob', age=5)] + [Row(name='Bob', age=5)] """ if on is not None and not isinstance(on, list): on = [on] if on is not None: - if isinstance(on[0], basestring): + if isinstance(on[0], str): on = self._jseq(on) else: assert isinstance(on[0], Column), "on should be Column or list of Column" @@ -1147,7 +1129,7 @@ def join(self, other, on=None, how=None): how = "inner" if on is None: on = self._jseq([]) - assert isinstance(how, basestring), "how should be basestring" + assert isinstance(how, str), "how should be a string" jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sql_ctx) @@ -1171,7 +1153,6 @@ def sortWithinPartitions(self, *cols, **kwargs): jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx) - @ignore_unicode_prefix @since(1.3) def sort(self, *cols, **kwargs): """Returns a new :class:`DataFrame` sorted by the specified column(s). @@ -1182,18 +1163,18 @@ def sort(self, *cols, **kwargs): If a list is specified, length of the list must equal length of the `cols`. >>> df.sort(df.age.desc()).collect() - [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + [Row(age=5, name='Bob'), Row(age=2, name='Alice')] >>> df.sort("age", ascending=False).collect() - [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + [Row(age=5, name='Bob'), Row(age=2, name='Alice')] >>> df.orderBy(df.age.desc()).collect() - [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + [Row(age=5, name='Bob'), Row(age=2, name='Alice')] >>> from pyspark.sql.functions import * >>> df.sort(asc("age")).collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + [Row(age=2, name='Alice'), Row(age=5, name='Bob')] >>> df.orderBy(desc("age"), "name").collect() - [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + [Row(age=5, name='Bob'), Row(age=2, name='Alice')] >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() - [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + [Row(age=5, name='Bob'), Row(age=2, name='Alice')] """ jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx) @@ -1333,7 +1314,6 @@ def summary(self, *statistics): jdf = self._jdf.summary(self._jseq(statistics)) return DataFrame(jdf, self.sql_ctx) - @ignore_unicode_prefix @since(1.3) def head(self, n=None): """Returns the first ``n`` rows. @@ -1346,26 +1326,24 @@ def head(self, n=None): If n is 1, return a single Row. 
>>> df.head() - Row(age=2, name=u'Alice') + Row(age=2, name='Alice') >>> df.head(1) - [Row(age=2, name=u'Alice')] + [Row(age=2, name='Alice')] """ if n is None: rs = self.head(1) return rs[0] if rs else None return self.take(n) - @ignore_unicode_prefix @since(1.3) def first(self): """Returns the first row as a :class:`Row`. >>> df.first() - Row(age=2, name=u'Alice') + Row(age=2, name='Alice') """ return self.head() - @ignore_unicode_prefix @since(1.3) def __getitem__(self, item): """Returns the column as a :class:`Column`. @@ -1373,13 +1351,13 @@ def __getitem__(self, item): >>> df.select(df['age']).collect() [Row(age=2), Row(age=5)] >>> df[ ["name", "age"]].collect() - [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + [Row(name='Alice', age=2), Row(name='Bob', age=5)] >>> df[ df.age > 3 ].collect() - [Row(age=5, name=u'Bob')] + [Row(age=5, name='Bob')] >>> df[df[0] > 3].collect() - [Row(age=5, name=u'Bob')] + [Row(age=5, name='Bob')] """ - if isinstance(item, basestring): + if isinstance(item, str): jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): @@ -1405,7 +1383,6 @@ def __getattr__(self, name): jc = self._jdf.apply(name) return Column(jc) - @ignore_unicode_prefix @since(1.3) def select(self, *cols): """Projects a set of expressions and returns a new :class:`DataFrame`. @@ -1415,11 +1392,11 @@ def select(self, *cols): in the current :class:`DataFrame`. >>> df.select('*').collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + [Row(age=2, name='Alice'), Row(age=5, name='Bob')] >>> df.select('name', 'age').collect() - [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + [Row(name='Alice', age=2), Row(name='Bob', age=5)] >>> df.select(df.name, (df.age + 10).alias('age')).collect() - [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] + [Row(name='Alice', age=12), Row(name='Bob', age=15)] """ jdf = self._jdf.select(self._jcols(*cols)) return DataFrame(jdf, self.sql_ctx) @@ -1438,7 +1415,6 @@ def selectExpr(self, *expr): jdf = self._jdf.selectExpr(self._jseq(expr)) return DataFrame(jdf, self.sql_ctx) - @ignore_unicode_prefix @since(1.3) def filter(self, condition): """Filters rows using the given condition. @@ -1449,16 +1425,16 @@ def filter(self, condition): or a string of SQL expression. 
>>> df.filter(df.age > 3).collect() - [Row(age=5, name=u'Bob')] + [Row(age=5, name='Bob')] >>> df.where(df.age == 2).collect() - [Row(age=2, name=u'Alice')] + [Row(age=2, name='Alice')] >>> df.filter("age > 3").collect() - [Row(age=5, name=u'Bob')] + [Row(age=5, name='Bob')] >>> df.where("age = 2").collect() - [Row(age=2, name=u'Alice')] + [Row(age=2, name='Alice')] """ - if isinstance(condition, basestring): + if isinstance(condition, str): jdf = self._jdf.filter(condition) elif isinstance(condition, Column): jdf = self._jdf.filter(condition._jc) @@ -1466,7 +1442,6 @@ def filter(self, condition): raise TypeError("condition should be string or Column") return DataFrame(jdf, self.sql_ctx) - @ignore_unicode_prefix @since(1.3) def groupBy(self, *cols): """Groups the :class:`DataFrame` using the specified columns, @@ -1481,11 +1456,11 @@ def groupBy(self, *cols): >>> df.groupBy().avg().collect() [Row(avg(age)=3.5)] >>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect()) - [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] + [Row(name='Alice', avg(age)=2.0), Row(name='Bob', avg(age)=5.0)] >>> sorted(df.groupBy(df.name).avg().collect()) - [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] + [Row(name='Alice', avg(age)=2.0), Row(name='Bob', avg(age)=5.0)] >>> sorted(df.groupBy(['name', df.age]).count().collect()) - [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)] + [Row(name='Alice', age=2, count=1), Row(name='Bob', age=5, count=1)] """ jgd = self._jdf.groupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData @@ -1655,19 +1630,19 @@ def dropDuplicates(self, subset=None): ... Row(name='Alice', age=5, height=80), \\ ... Row(name='Alice', age=10, height=80)]).toDF() >>> df.dropDuplicates().show() - +---+------+-----+ - |age|height| name| - +---+------+-----+ - | 5| 80|Alice| - | 10| 80|Alice| - +---+------+-----+ + +-----+---+------+ + | name|age|height| + +-----+---+------+ + |Alice| 5| 80| + |Alice| 10| 80| + +-----+---+------+ >>> df.dropDuplicates(['name', 'height']).show() - +---+------+-----+ - |age|height| name| - +---+------+-----+ - | 5| 80|Alice| - +---+------+-----+ + +-----+---+------+ + | name|age|height| + +-----+---+------+ + |Alice| 5| 80| + +-----+---+------+ """ if subset is None: jdf = self._jdf.dropDuplicates() @@ -1700,7 +1675,7 @@ def dropna(self, how='any', thresh=None, subset=None): if subset is None: subset = self.columns - elif isinstance(subset, basestring): + elif isinstance(subset, str): subset = [subset] elif not isinstance(subset, (list, tuple)): raise ValueError("subset should be a list or tuple of column names") @@ -1715,11 +1690,11 @@ def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other. - :param value: int, long, float, string, bool or dict. + :param value: int, float, string, bool or dict. Value to replace null values with. If the value is a dict, then `subset` is ignored and `value` must be a mapping from column name (string) to replacement value. The replacement value must be - an int, long, float, boolean, or string. + an int, float, boolean, or string. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. 
For example, if `value` is a string, and subset contains a non-string column, @@ -1754,13 +1729,13 @@ def fillna(self, value, subset=None): | 50| null|unknown| +---+------+-------+ """ - if not isinstance(value, (float, int, long, basestring, bool, dict)): - raise ValueError("value should be a float, int, long, string, bool or dict") + if not isinstance(value, (float, int, str, bool, dict)): + raise ValueError("value should be a float, int, string, bool or dict") # Note that bool validates isinstance(int), but we don't want to # convert bools to floats - if not isinstance(value, bool) and isinstance(value, (int, long)): + if not isinstance(value, bool) and isinstance(value, int): value = float(value) if isinstance(value, dict): @@ -1768,7 +1743,7 @@ def fillna(self, value, subset=None): elif subset is None: return DataFrame(self._jdf.na().fill(value), self.sql_ctx) else: - if isinstance(subset, basestring): + if isinstance(subset, str): subset = [subset] elif not isinstance(subset, (list, tuple)): raise ValueError("subset should be a list or tuple of column names") @@ -1787,12 +1762,12 @@ def replace(self, to_replace, value=_NoValue, subset=None): floating point representation. In case of conflicts (for example with `{42: -1, 42.0: 1}`) and arbitrary replacement will be used. - :param to_replace: bool, int, long, float, string, list or dict. + :param to_replace: bool, int, float, string, list or dict. Value to be replaced. If the value is a dict, then `value` is ignored or can be omitted, and `to_replace` must be a mapping between a value and a replacement. - :param value: bool, int, long, float, string, list or None. - The replacement value must be a bool, int, long, float, string or None. If `value` is a + :param value: bool, int, float, string, list or None. + The replacement value must be a bool, int, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`. If `value` is a scalar and `to_replace` is a sequence, then `value` is used as a replacement for each item in `to_replace`. @@ -1854,7 +1829,7 @@ def all_of(types): >>> all_of(bool)([True, False]) True - >>> all_of(basestring)(["a", 1]) + >>> all_of(str)(["a", 1]) False """ def all_of_(xs): @@ -1862,20 +1837,20 @@ def all_of_(xs): return all_of_ all_of_bool = all_of(bool) - all_of_str = all_of(basestring) - all_of_numeric = all_of((float, int, long)) + all_of_str = all_of(str) + all_of_numeric = all_of((float, int)) # Validate input types - valid_types = (bool, float, int, long, basestring, list, tuple) + valid_types = (bool, float, int, str, list, tuple) if not isinstance(to_replace, valid_types + (dict, )): raise ValueError( - "to_replace should be a bool, float, int, long, string, list, tuple, or dict. " + "to_replace should be a bool, float, int, string, list, tuple, or dict. " "Got {0}".format(type(to_replace))) if not isinstance(value, valid_types) and value is not None \ and not isinstance(to_replace, dict): raise ValueError("If to_replace is not a dict, value should be " - "a bool, float, int, long, string, list, tuple or None. " + "a bool, float, int, string, list, tuple or None. " "Got {0}".format(type(value))) if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): @@ -1883,12 +1858,12 @@ def all_of_(xs): raise ValueError("to_replace and value lists should be of the same length. 
" "Got {0} and {1}".format(len(to_replace), len(value))) - if not (subset is None or isinstance(subset, (list, tuple, basestring))): + if not (subset is None or isinstance(subset, (list, tuple, str))): raise ValueError("subset should be a list or tuple of column names, " "column name or None. Got {0}".format(type(subset))) # Reshape input arguments if necessary - if isinstance(to_replace, (float, int, long, basestring)): + if isinstance(to_replace, (float, int, str)): to_replace = [to_replace] if isinstance(to_replace, dict): @@ -1896,11 +1871,11 @@ def all_of_(xs): if value is not None: warnings.warn("to_replace is a dict and value is not None. value will be ignored.") else: - if isinstance(value, (float, int, long, basestring)) or value is None: + if isinstance(value, (float, int, str)) or value is None: value = [value for _ in range(len(to_replace))] rep_dict = dict(zip(to_replace, value)) - if isinstance(subset, basestring): + if isinstance(subset, str): subset = [subset] # Verify we were not passed in mixed type generics. @@ -1957,10 +1932,10 @@ def approxQuantile(self, col, probabilities, relativeError): Added support for multiple columns. """ - if not isinstance(col, (basestring, list, tuple)): + if not isinstance(col, (str, list, tuple)): raise ValueError("col should be a string, list or tuple, but got %r" % type(col)) - isStr = isinstance(col, basestring) + isStr = isinstance(col, str) if isinstance(col, tuple): col = list(col) @@ -1968,7 +1943,7 @@ def approxQuantile(self, col, probabilities, relativeError): col = [col] for c in col: - if not isinstance(c, basestring): + if not isinstance(c, str): raise ValueError("columns should be strings, but got %r" % type(c)) col = _to_list(self._sc, col) @@ -1977,12 +1952,12 @@ def approxQuantile(self, col, probabilities, relativeError): if isinstance(probabilities, tuple): probabilities = list(probabilities) for p in probabilities: - if not isinstance(p, (float, int, long)) or p < 0 or p > 1: - raise ValueError("probabilities should be numerical (float, int, long) in [0,1].") + if not isinstance(p, (float, int)) or p < 0 or p > 1: + raise ValueError("probabilities should be numerical (float, int) in [0,1].") probabilities = _to_list(self._sc, probabilities) - if not isinstance(relativeError, (float, int, long)) or relativeError < 0: - raise ValueError("relativeError should be numerical (float, int, long) >= 0.") + if not isinstance(relativeError, (float, int)) or relativeError < 0: + raise ValueError("relativeError should be numerical (float, int) >= 0.") relativeError = float(relativeError) jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError) @@ -2000,9 +1975,9 @@ def corr(self, col1, col2, method=None): :param col2: The name of the second column :param method: The correlation method. 
Currently only supports "pearson" """ - if not isinstance(col1, basestring): + if not isinstance(col1, str): raise ValueError("col1 should be a string.") - if not isinstance(col2, basestring): + if not isinstance(col2, str): raise ValueError("col2 should be a string.") if not method: method = "pearson" @@ -2020,9 +1995,9 @@ def cov(self, col1, col2): :param col1: The name of the first column :param col2: The name of the second column """ - if not isinstance(col1, basestring): + if not isinstance(col1, str): raise ValueError("col1 should be a string.") - if not isinstance(col2, basestring): + if not isinstance(col2, str): raise ValueError("col2 should be a string.") return self._jdf.stat().cov(col1, col2) @@ -2042,9 +2017,9 @@ def crosstab(self, col1, col2): :param col2: The name of the second column. Distinct items will make the column names of the :class:`DataFrame`. """ - if not isinstance(col1, basestring): + if not isinstance(col1, str): raise ValueError("col1 should be a string.") - if not isinstance(col2, basestring): + if not isinstance(col2, str): raise ValueError("col2 should be a string.") return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx) @@ -2073,7 +2048,6 @@ def freqItems(self, cols, support=None): support = 0.01 return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx) - @ignore_unicode_prefix @since(1.3) def withColumn(self, colName, col): """ @@ -2092,13 +2066,12 @@ def withColumn(self, colName, col): To avoid this, use :func:`select` with the multiple columns at once. >>> df.withColumn('age2', df.age + 2).collect() - [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + [Row(age=2, name='Alice', age2=4), Row(age=5, name='Bob', age2=7)] """ assert isinstance(col, Column), "col should be Column" return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) - @ignore_unicode_prefix @since(1.3) def withColumnRenamed(self, existing, new): """Returns a new :class:`DataFrame` by renaming an existing column. @@ -2108,12 +2081,11 @@ def withColumnRenamed(self, existing, new): :param new: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() - [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] + [Row(age2=2, name='Alice'), Row(age2=5, name='Bob')] """ return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx) @since(1.4) - @ignore_unicode_prefix def drop(self, *cols): """Returns a new :class:`DataFrame` that drops the specified column. This is a no-op if schema doesn't contain the given column name(s). @@ -2122,23 +2094,23 @@ def drop(self, *cols): :class:`Column` to drop, or a list of string name of the columns to drop. 
>>> df.drop('age').collect() - [Row(name=u'Alice'), Row(name=u'Bob')] + [Row(name='Alice'), Row(name='Bob')] >>> df.drop(df.age).collect() - [Row(name=u'Alice'), Row(name=u'Bob')] + [Row(name='Alice'), Row(name='Bob')] >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect() - [Row(age=5, height=85, name=u'Bob')] + [Row(age=5, height=85, name='Bob')] >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect() - [Row(age=5, name=u'Bob', height=85)] + [Row(age=5, name='Bob', height=85)] >>> df.join(df2, 'name', 'inner').drop('age', 'height').collect() - [Row(name=u'Bob')] + [Row(name='Bob')] """ if len(cols) == 1: col = cols[0] - if isinstance(col, basestring): + if isinstance(col, str): jdf = self._jdf.drop(col) elif isinstance(col, Column): jdf = self._jdf.drop(col._jc) @@ -2146,20 +2118,19 @@ def drop(self, *cols): raise TypeError("col should be a string or a Column") else: for col in cols: - if not isinstance(col, basestring): + if not isinstance(col, str): raise TypeError("each col in the param list should be a string") jdf = self._jdf.drop(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) - @ignore_unicode_prefix def toDF(self, *cols): """Returns a new :class:`DataFrame` that with new specified column names :param cols: list of new column names (string) >>> df.toDF('f1', 'f2').collect() - [Row(f1=2, f2=u'Alice'), Row(f1=5, f2=u'Bob')] + [Row(f1=2, f2='Alice'), Row(f1=5, f2='Bob')] """ jdf = self._jdf.toDF(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) @@ -2347,7 +2318,6 @@ def _test(): from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext, SparkSession import pyspark.sql.dataframe - from pyspark.sql.functions import from_unixtime globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc @@ -2356,16 +2326,16 @@ def _test(): globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) - globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() - globs['df3'] = sc.parallelize([Row(name='Alice', age=2), - Row(name='Bob', age=5)]).toDF() - globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), - Row(name='Bob', age=5, height=None), - Row(name='Tom', age=None, height=None), - Row(name=None, age=None, height=None)]).toDF() - globs['df5'] = sc.parallelize([Row(name='Alice', spy=False, age=10), - Row(name='Bob', spy=None, age=5), - Row(name='Mallory', spy=True, age=None)]).toDF() + globs['df2'] = sc.parallelize([Row(height=80, name='Tom'), Row(height=85, name='Bob')]).toDF() + globs['df3'] = sc.parallelize([Row(age=2, name='Alice'), + Row(age=5, name='Bob')]).toDF() + globs['df4'] = sc.parallelize([Row(age=10, height=80, name='Alice'), + Row(age=5, height=None, name='Bob'), + Row(age=None, height=None, name='Tom'), + Row(age=None, height=None, name=None)]).toDF() + globs['df5'] = sc.parallelize([Row(age=10, name='Alice', spy=False), + Row(age=5, name='Bob', spy=None), + Row(age=None, name='Mallory', spy=True)]).toDF() globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846), Row(name='Bob', time=1479442946)]).toDF() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b5a7c18904b14..5a352104c4eca 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -22,14 +22,8 @@ import functools import warnings -if sys.version < "3": - from itertools import imap as map - -if sys.version >= 
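# Editorial note (hedged): the doctest fixtures above are rewritten with keyword
# arguments in the column order the doctests expect. On Python 3.6+, Row built from
# kwargs preserves the call-site order instead of sorting field names alphabetically
# (per the Spark 3.0 migration notes), so the fixture order now drives column order.
# Minimal illustration, assuming pyspark is importable locally (no JVM needed for Row):
from pyspark.sql import Row

r = Row(height=80, name='Tom')
print(r)            # Row(height=80, name='Tom'), field order follows the call site
print(r.name)       # 'Tom'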
'3': - basestring = str - from pyspark import since, SparkContext -from pyspark.rdd import ignore_unicode_prefix, PythonEvalType +from pyspark.rdd import PythonEvalType from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal, \ _create_column_from_name from pyspark.sql.dataframe import DataFrame @@ -88,14 +82,14 @@ def _(col1, col2): # if they are not columns or strings. if isinstance(col1, Column): arg1 = col1._jc - elif isinstance(col1, basestring): + elif isinstance(col1, str): arg1 = _create_column_from_name(col1) else: arg1 = float(col1) if isinstance(col2, Column): arg2 = col2._jc - elif isinstance(col2, basestring): + elif isinstance(col2, str): arg2 = _create_column_from_name(col2) else: arg2 = float(col2) @@ -648,7 +642,6 @@ def percentile_approx(col, percentage, accuracy=10000): return Column(sc._jvm.functions.percentile_approx(_to_java_column(col), percentage, accuracy)) -@ignore_unicode_prefix @since(1.4) def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples @@ -657,8 +650,8 @@ def rand(seed=None): .. note:: The function is non-deterministic in general case. >>> df.withColumn('rand', rand(seed=42) * 3).collect() - [Row(age=2, name=u'Alice', rand=2.4052597283576684), - Row(age=5, name=u'Bob', rand=2.3913904055683974)] + [Row(age=2, name='Alice', rand=2.4052597283576684), + Row(age=5, name='Bob', rand=2.3913904055683974)] """ sc = SparkContext._active_spark_context if seed is not None: @@ -668,7 +661,6 @@ def rand(seed=None): return Column(jc) -@ignore_unicode_prefix @since(1.4) def randn(seed=None): """Generates a column with independent and identically distributed (i.i.d.) samples from @@ -677,8 +669,8 @@ def randn(seed=None): .. note:: The function is non-deterministic in general case. >>> df.withColumn('randn', randn(seed=42)).collect() - [Row(age=2, name=u'Alice', randn=1.1027054481455365), - Row(age=5, name=u'Bob', randn=0.7400395449950132)] + [Row(age=2, name='Alice', randn=1.1027054481455365), + Row(age=5, name='Bob', randn=0.7400395449950132)] """ sc = SparkContext._active_spark_context if seed is not None: @@ -774,7 +766,6 @@ def expr(str): return Column(sc._jvm.functions.expr(str)) -@ignore_unicode_prefix @since(1.4) def struct(*cols): """Creates a new struct column. @@ -782,9 +773,9 @@ def struct(*cols): :param cols: list of column names (string) or list of :class:`Column` expressions >>> df.select(struct('age', 'name').alias("struct")).collect() - [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + [Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))] >>> df.select(struct([df.age, df.name]).alias("struct")).collect() - [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + [Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))] """ sc = SparkContext._active_spark_context if len(cols) == 1 and isinstance(cols[0], (list, set)): @@ -879,14 +870,13 @@ def log2(col): @since(1.5) -@ignore_unicode_prefix def conv(col, fromBase, toBase): """ Convert a number in a string column from one base to another. 
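# Editorial note (hedged): the dispatch above lets the generated binary math wrappers
# accept a Column, a column name (now a plain str), or a numeric literal for either
# argument. Illustrative usage, assuming an active SparkSession named `spark`:
from pyspark.sql.functions import atan2, col

df = spark.createDataFrame([(1.0, 2.0)], ["y", "x"])
df.select(
    atan2(col("y"), col("x")),   # Column, Column
    atan2("y", "x"),             # column names as strings
    atan2(1.0, "x"),             # float literal mixed with a column name
).show()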
>>> df = spark.createDataFrame([("010101",)], ['n']) >>> df.select(conv(df.n, 2, 16).alias('hex')).collect() - [Row(hex=u'15')] + [Row(hex='15')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.conv(_to_java_column(col), fromBase, toBase)) @@ -976,7 +966,6 @@ def current_timestamp(): return Column(sc._jvm.functions.current_timestamp()) -@ignore_unicode_prefix @since(1.5) def date_format(date, format): """ @@ -992,7 +981,7 @@ def date_format(date, format): >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) >>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect() - [Row(date=u'04/08/2015')] + [Row(date='04/08/2015')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.date_format(_to_java_column(date), format)) @@ -1310,7 +1299,6 @@ def last_day(date): return Column(sc._jvm.functions.last_day(_to_java_column(date))) -@ignore_unicode_prefix @since(1.5) def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"): """ @@ -1321,7 +1309,7 @@ def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"): >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time']) >>> time_df.select(from_unixtime('unix_time').alias('ts')).collect() - [Row(ts=u'2015-04-08 00:00:00')] + [Row(ts='2015-04-08 00:00:00')] >>> spark.conf.unset("spark.sql.session.timeZone") """ sc = SparkContext._active_spark_context @@ -1447,7 +1435,6 @@ def timestamp_seconds(col): @since(2.0) -@ignore_unicode_prefix def window(timeColumn, windowDuration, slideDuration=None, startTime=None): """Bucketize rows into one or more time windows given a timestamp specifying column. Window starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window @@ -1471,7 +1458,7 @@ def window(timeColumn, windowDuration, slideDuration=None, startTime=None): >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) >>> w.select(w.window.start.cast("string").alias("start"), ... w.window.end.cast("string").alias("end"), "sum").collect() - [Row(start=u'2016-03-11 09:00:05', end=u'2016-03-11 09:00:10', sum=1)] + [Row(start='2016-03-11 09:00:05', end='2016-03-11 09:00:10', sum=1)] """ def check_string_field(field, fieldName): if not field or type(field) is not str: @@ -1498,7 +1485,6 @@ def check_string_field(field, fieldName): # ---------------------------- misc functions ---------------------------------- @since(1.5) -@ignore_unicode_prefix def crc32(col): """ Calculates the cyclic redundancy check value (CRC32) of a binary column and @@ -1511,33 +1497,30 @@ def crc32(col): return Column(sc._jvm.functions.crc32(_to_java_column(col))) -@ignore_unicode_prefix @since(1.5) def md5(col): """Calculates the MD5 digest and returns the value as a 32 character hex string. >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() - [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] + [Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.md5(_to_java_column(col)) return Column(jc) -@ignore_unicode_prefix @since(1.5) def sha1(col): """Returns the hex string result of SHA-1. 
>>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() - [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + [Row(hash='3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.sha1(_to_java_column(col)) return Column(jc) -@ignore_unicode_prefix @since(1.5) def sha2(col, numBits): """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, @@ -1546,9 +1529,9 @@ def sha2(col, numBits): >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() >>> digests[0] - Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') + Row(s='3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') >>> digests[1] - Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + Row(s='cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) @@ -1600,7 +1583,6 @@ def xxhash64(*cols): @since(1.5) -@ignore_unicode_prefix def concat_ws(sep, *cols): """ Concatenates multiple input string columns together into a single string column, @@ -1608,7 +1590,7 @@ def concat_ws(sep, *cols): >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() - [Row(s=u'abcd-123')] + [Row(s='abcd-123')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.concat_ws(sep, _to_seq(sc, cols, _to_java_column))) @@ -1634,7 +1616,6 @@ def encode(col, charset): return Column(sc._jvm.functions.encode(_to_java_column(col), charset)) -@ignore_unicode_prefix @since(1.5) def format_number(col, d): """ @@ -1645,13 +1626,12 @@ def format_number(col, d): :param d: the N decimal places >>> spark.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() - [Row(v=u'5.0000')] + [Row(v='5.0000')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) -@ignore_unicode_prefix @since(1.5) def format_string(format, *cols): """ @@ -1663,7 +1643,7 @@ def format_string(format, *cols): >>> df = spark.createDataFrame([(5, "hello")], ['a', 'b']) >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect() - [Row(v=u'5 hello')] + [Row(v='5 hello')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.format_string(format, _to_seq(sc, cols, _to_java_column))) @@ -1721,7 +1701,6 @@ def overlay(src, replace, pos, len=-1): @since(1.5) -@ignore_unicode_prefix def substring(str, pos, len): """ Substring starts at `pos` and is of length `len` when str is String type or @@ -1732,14 +1711,13 @@ def substring(str, pos, len): >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(substring(df.s, 1, 2).alias('s')).collect() - [Row(s=u'ab')] + [Row(s='ab')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len)) @since(1.5) -@ignore_unicode_prefix def substring_index(str, delim, count): """ Returns the substring from string str before count occurrences of the delimiter delim. 
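# Editorial note (hedged): a quick usage sketch for the string helpers above, assuming
# an active SparkSession named `spark`. Note that substring() is 1-based, matching the
# SQL function rather than Python slicing:
from pyspark.sql.functions import concat_ws, format_string, substring

df = spark.createDataFrame([("spark", "sql")], ["a", "b"])
df.select(
    concat_ws("-", "a", "b").alias("joined"),        # 'spark-sql'
    format_string("%s/%s", "a", "b").alias("path"),  # 'spark/sql'
    substring("a", 1, 3).alias("prefix"),            # 'spa' (1-based start)
).show()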
@@ -1749,15 +1727,14 @@ def substring_index(str, delim, count): >>> df = spark.createDataFrame([('a.b.c.d',)], ['s']) >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect() - [Row(s=u'a.b')] + [Row(s='a.b')] >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect() - [Row(s=u'b.c.d')] + [Row(s='b.c.d')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count)) -@ignore_unicode_prefix @since(1.5) def levenshtein(left, right): """Computes the Levenshtein distance of the two given strings. @@ -1792,49 +1769,45 @@ def locate(substr, str, pos=1): @since(1.5) -@ignore_unicode_prefix def lpad(col, len, pad): """ Left-pad the string column to width `len` with `pad`. >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() - [Row(s=u'##abcd')] + [Row(s='##abcd')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.lpad(_to_java_column(col), len, pad)) @since(1.5) -@ignore_unicode_prefix def rpad(col, len, pad): """ Right-pad the string column to width `len` with `pad`. >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() - [Row(s=u'abcd##')] + [Row(s='abcd##')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.rpad(_to_java_column(col), len, pad)) @since(1.5) -@ignore_unicode_prefix def repeat(col, n): """ Repeats a string column n times, and returns it as a new string column. >>> df = spark.createDataFrame([('ab',)], ['s',]) >>> df.select(repeat(df.s, 3).alias('s')).collect() - [Row(s=u'ababab')] + [Row(s='ababab')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.repeat(_to_java_column(col), n)) @since(1.5) -@ignore_unicode_prefix def split(str, pattern, limit=-1): """ Splits str around matches of the given pattern. @@ -1855,15 +1828,14 @@ def split(str, pattern, limit=-1): >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',]) >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect() - [Row(s=[u'one', u'twoBthreeC'])] + [Row(s=['one', 'twoBthreeC'])] >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect() - [Row(s=[u'one', u'two', u'three', u''])] + [Row(s=['one', 'two', 'three', ''])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.split(_to_java_column(str), pattern, limit)) -@ignore_unicode_prefix @since(1.5) def regexp_extract(str, pattern, idx): r"""Extract a specific group matched by a Java regex, from the specified string column. @@ -1871,73 +1843,68 @@ def regexp_extract(str, pattern, idx): >>> df = spark.createDataFrame([('100-200',)], ['str']) >>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect() - [Row(d=u'100')] + [Row(d='100')] >>> df = spark.createDataFrame([('foo',)], ['str']) >>> df.select(regexp_extract('str', r'(\d+)', 1).alias('d')).collect() - [Row(d=u'')] + [Row(d='')] >>> df = spark.createDataFrame([('aaaac',)], ['str']) >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() - [Row(d=u'')] + [Row(d='')] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) return Column(jc) -@ignore_unicode_prefix @since(1.5) def regexp_replace(str, pattern, replacement): r"""Replace all substrings of the specified string value that match regexp with rep. 
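# Editorial note (hedged): for strings shorter than the target width, lpad/rpad above
# behave like Python's str.rjust/str.ljust with a custom fill character, which is a
# handy way to sanity-check the doctest outputs locally without a SparkSession:
assert 'abcd'.rjust(6, '#') == '##abcd'   # cf. lpad(df.s, 6, '#')
assert 'abcd'.ljust(6, '#') == 'abcd##'   # cf. rpad(df.s, 6, '#')
assert 'ab' * 3 == 'ababab'               # cf. repeat(df.s, 3)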
>>> df = spark.createDataFrame([('100-200',)], ['str']) >>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() - [Row(d=u'-----')] + [Row(d='-----')] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement) return Column(jc) -@ignore_unicode_prefix @since(1.5) def initcap(col): """Translate the first letter of each word to upper case in the sentence. >>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() - [Row(v=u'Ab Cd')] + [Row(v='Ab Cd')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.initcap(_to_java_column(col))) @since(1.5) -@ignore_unicode_prefix def soundex(col): """ Returns the SoundEx encoding for a string >>> df = spark.createDataFrame([("Peters",),("Uhrbach",)], ['name']) >>> df.select(soundex(df.name).alias("soundex")).collect() - [Row(soundex=u'P362'), Row(soundex=u'U612')] + [Row(soundex='P362'), Row(soundex='U612')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.soundex(_to_java_column(col))) -@ignore_unicode_prefix @since(1.5) def bin(col): """Returns the string representation of the binary value of the given column. >>> df.select(bin(df.age).alias('c')).collect() - [Row(c=u'10'), Row(c=u'101')] + [Row(c='10'), Row(c='101')] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.bin(_to_java_column(col)) return Column(jc) -@ignore_unicode_prefix @since(1.5) def hex(col): """Computes hex value of the given column, which could be :class:`pyspark.sql.types.StringType`, @@ -1945,14 +1912,13 @@ def hex(col): :class:`pyspark.sql.types.LongType`. >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() - [Row(hex(a)=u'414243', hex(b)=u'3')] + [Row(hex(a)='414243', hex(b)='3')] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.hex(_to_java_column(col)) return Column(jc) -@ignore_unicode_prefix @since(1.5) def unhex(col): """Inverse of hex. Interprets each pair of characters as a hexadecimal number @@ -1965,7 +1931,6 @@ def unhex(col): return Column(sc._jvm.functions.unhex(_to_java_column(col))) -@ignore_unicode_prefix @since(1.5) def length(col): """Computes the character length of string data or number of bytes of binary data. @@ -1979,7 +1944,6 @@ def length(col): return Column(sc._jvm.functions.length(_to_java_column(col))) -@ignore_unicode_prefix @since(1.5) def translate(srcCol, matching, replace): """A function translate any character in the `srcCol` by a character in `matching`. @@ -1989,7 +1953,7 @@ def translate(srcCol, matching, replace): >>> spark.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123") \\ ... .alias('r')).collect() - [Row(r=u'1a2s3ae')] + [Row(r='1a2s3ae')] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace)) @@ -1997,7 +1961,6 @@ def translate(srcCol, matching, replace): # ---------------------- Collection functions ------------------------------ -@ignore_unicode_prefix @since(2.0) def create_map(*cols): """Creates a new map column. @@ -2006,9 +1969,9 @@ def create_map(*cols): grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...). 
>>> df.select(create_map('name', 'age').alias("map")).collect() - [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] + [Row(map={'Alice': 2}), Row(map={'Bob': 5})] >>> df.select(create_map([df.name, df.age]).alias("map")).collect() - [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] + [Row(map={'Alice': 2}), Row(map={'Bob': 5})] """ sc = SparkContext._active_spark_context if len(cols) == 1 and isinstance(cols[0], (list, set)): @@ -2108,7 +2071,6 @@ def slice(x, start, length): return Column(sc._jvm.functions.slice(_to_java_column(x), start, length)) -@ignore_unicode_prefix @since(2.4) def array_join(col, delimiter, null_replacement=None): """ @@ -2117,9 +2079,9 @@ def array_join(col, delimiter, null_replacement=None): >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) >>> df.select(array_join(df.data, ",").alias("joined")).collect() - [Row(joined=u'a,b,c'), Row(joined=u'a')] + [Row(joined='a,b,c'), Row(joined='a')] >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() - [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')] + [Row(joined='a,b,c'), Row(joined='a,NULL')] """ sc = SparkContext._active_spark_context if null_replacement is None: @@ -2130,7 +2092,6 @@ def array_join(col, delimiter, null_replacement=None): @since(1.5) -@ignore_unicode_prefix def concat(*cols): """ Concatenates multiple input columns together into a single column. @@ -2138,7 +2099,7 @@ def concat(*cols): >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) >>> df.select(concat(df.s, df.d).alias('s')).collect() - [Row(s=u'abcd123')] + [Row(s='abcd123')] >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() @@ -2165,7 +2126,6 @@ def array_position(col, value): return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) -@ignore_unicode_prefix @since(2.4) def element_at(col, extraction): """ @@ -2179,7 +2139,7 @@ def element_at(col, extraction): >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(element_at(df.data, 1)).collect() - [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] + [Row(element_at(data, 1)='a'), Row(element_at(data, 1)=None)] >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) >>> df.select(element_at(df.data, lit("a"))).collect() @@ -2221,7 +2181,6 @@ def array_distinct(col): return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) -@ignore_unicode_prefix @since(2.4) def array_intersect(col1, col2): """ @@ -2234,13 +2193,12 @@ def array_intersect(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_intersect(df.c1, df.c2)).collect() - [Row(array_intersect(c1, c2)=[u'a', u'c'])] + [Row(array_intersect(c1, c2)=['a', 'c'])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) -@ignore_unicode_prefix @since(2.4) def array_union(col1, col2): """ @@ -2253,13 +2211,12 @@ def array_union(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_union(df.c1, df.c2)).collect() - [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])] + [Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_union(_to_java_column(col1), 
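# Editorial note (hedged): a small usage sketch for the array helpers above, assuming
# an active SparkSession named `spark`. element_at() is 1-based and, for arrays in the
# default (non-ANSI) mode, yields NULL when the index is out of range:
from pyspark.sql.functions import array_join, element_at

df = spark.createDataFrame([(["a", "b", None],)], ["data"])
df.select(
    array_join("data", ",", "NULL").alias("joined"),   # 'a,b,NULL'
    element_at("data", 1).alias("first"),              # 'a'
    element_at("data", 5).alias("missing"),            # NULL (out of range)
).show()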
_to_java_column(col2))) -@ignore_unicode_prefix @since(2.4) def array_except(col1, col2): """ @@ -2272,7 +2229,7 @@ def array_except(col1, col2): >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_except(df.c1, df.c2)).collect() - [Row(array_except(c1, c2)=[u'b'])] + [Row(array_except(c1, c2)=['b'])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2))) @@ -2397,7 +2354,6 @@ def posexplode_outer(col): return Column(jc) -@ignore_unicode_prefix @since(1.6) def get_json_object(col, path): """ @@ -2411,14 +2367,13 @@ def get_json_object(col, path): >>> df = spark.createDataFrame(data, ("key", "jstring")) >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \\ ... get_json_object(df.jstring, '$.f2').alias("c1") ).collect() - [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] + [Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) return Column(jc) -@ignore_unicode_prefix @since(1.6) def json_tuple(col, *fields): """Creates a new row for a json column according to the given field names. @@ -2429,14 +2384,13 @@ def json_tuple(col, *fields): >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] >>> df = spark.createDataFrame(data, ("key", "jstring")) >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() - [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] + [Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.json_tuple(_to_java_column(col), _to_seq(sc, fields)) return Column(jc) -@ignore_unicode_prefix @since(2.1) def from_json(col, schema, options={}): """ @@ -2460,7 +2414,7 @@ def from_json(col, schema, options={}): >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "MAP").alias("json")).collect() - [Row(json={u'a': 1})] + [Row(json={'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) @@ -2485,7 +2439,6 @@ def from_json(col, schema, options={}): return Column(jc) -@ignore_unicode_prefix @since(2.1) def to_json(col, options={}): """ @@ -2499,26 +2452,26 @@ def to_json(col, options={}): >>> from pyspark.sql import Row >>> from pyspark.sql.types import * - >>> data = [(1, Row(name='Alice', age=2))] + >>> data = [(1, Row(age=2, name='Alice'))] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() - [Row(json=u'{"age":2,"name":"Alice"}')] - >>> data = [(1, [Row(name='Alice', age=2), Row(name='Bob', age=3)])] + [Row(json='{"age":2,"name":"Alice"}')] + >>> data = [(1, [Row(age=2, name='Alice'), Row(age=3, name='Bob')])] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() - [Row(json=u'[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')] + [Row(json='[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')] >>> data = [(1, {"name": "Alice"})] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() - [Row(json=u'{"name":"Alice"}')] + 
[Row(json='{"name":"Alice"}')] >>> data = [(1, [{"name": "Alice"}, {"name": "Bob"}])] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() - [Row(json=u'[{"name":"Alice"},{"name":"Bob"}]')] + [Row(json='[{"name":"Alice"},{"name":"Bob"}]')] >>> data = [(1, ["Alice", "Bob"])] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() - [Row(json=u'["Alice","Bob"]')] + [Row(json='["Alice","Bob"]')] """ sc = SparkContext._active_spark_context @@ -2526,7 +2479,6 @@ def to_json(col, options={}): return Column(jc) -@ignore_unicode_prefix @since(2.4) def schema_of_json(json, options={}): """ @@ -2540,12 +2492,12 @@ def schema_of_json(json, options={}): >>> df = spark.range(1) >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() - [Row(json=u'struct')] + [Row(json='struct')] >>> schema = schema_of_json('{a: 1}', {'allowUnquotedFieldNames':'true'}) >>> df.select(schema.alias("json")).collect() - [Row(json=u'struct')] + [Row(json='struct')] """ - if isinstance(json, basestring): + if isinstance(json, str): col = _create_column_from_literal(json) elif isinstance(json, Column): col = _to_java_column(json) @@ -2557,7 +2509,6 @@ def schema_of_json(json, options={}): return Column(jc) -@ignore_unicode_prefix @since(3.0) def schema_of_csv(csv, options={}): """ @@ -2568,11 +2519,11 @@ def schema_of_csv(csv, options={}): >>> df = spark.range(1) >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect() - [Row(csv=u'struct<_c0:int,_c1:string>')] + [Row(csv='struct<_c0:int,_c1:string>')] >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect() - [Row(csv=u'struct<_c0:int,_c1:string>')] + [Row(csv='struct<_c0:int,_c1:string>')] """ - if isinstance(csv, basestring): + if isinstance(csv, str): col = _create_column_from_literal(csv) elif isinstance(csv, Column): col = _to_java_column(csv) @@ -2584,7 +2535,6 @@ def schema_of_csv(csv, options={}): return Column(jc) -@ignore_unicode_prefix @since(3.0) def to_csv(col, options={}): """ @@ -2595,10 +2545,10 @@ def to_csv(col, options={}): :param options: options to control converting. accepts the same options as the CSV datasource. >>> from pyspark.sql import Row - >>> data = [(1, Row(name='Alice', age=2))] + >>> data = [(1, Row(age=2, name='Alice'))] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_csv(df.value).alias("csv")).collect() - [Row(csv=u'2,Alice')] + [Row(csv='2,Alice')] """ sc = SparkContext._active_spark_context @@ -2705,7 +2655,6 @@ def shuffle(col): @since(1.5) -@ignore_unicode_prefix def reverse(col): """ Collection function: returns a reversed string or an array with reverse order of elements. 
@@ -2714,7 +2663,7 @@ def reverse(col): >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) >>> df.select(reverse(df.data).alias('s')).collect() - [Row(s=u'LQS krapS')] + [Row(s='LQS krapS')] >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) >>> df.select(reverse(df.data).alias('r')).collect() [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] @@ -2820,7 +2769,6 @@ def map_from_entries(col): return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) -@ignore_unicode_prefix @since(2.4) def array_repeat(col, count): """ @@ -2828,7 +2776,7 @@ def array_repeat(col, count): >>> df = spark.createDataFrame([('ab',)], ['data']) >>> df.select(array_repeat(df.data, 3).alias('r')).collect() - [Row(r=[u'ab', u'ab', u'ab'])] + [Row(r=['ab', 'ab', 'ab'])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_repeat( @@ -2898,7 +2846,6 @@ def sequence(start, stop, step=None): _to_java_column(start), _to_java_column(stop), _to_java_column(step))) -@ignore_unicode_prefix @since(3.0) def from_csv(col, schema, options={}): """ @@ -2920,11 +2867,11 @@ def from_csv(col, schema, options={}): >>> df = spark.createDataFrame(data, ("value",)) >>> options = {'ignoreLeadingWhiteSpace': True} >>> df.select(from_csv(df.value, "s string", options).alias("csv")).collect() - [Row(csv=Row(s=u'abc'))] + [Row(csv=Row(s='abc'))] """ sc = SparkContext._active_spark_context - if isinstance(schema, basestring): + if isinstance(schema, str): schema = _create_column_from_literal(schema) elif isinstance(schema, Column): schema = _to_java_column(schema) @@ -2984,20 +2931,6 @@ def _get_lambda_parameters(f): return parameters -def _get_lambda_parameters_legacy(f): - # TODO (SPARK-29909) Remove once 2.7 support is dropped - import inspect - - spec = inspect.getargspec(f) - if not 1 <= len(spec.args) <= 3 or spec.varargs or spec.keywords: - raise ValueError( - "f should take between 1 and 3 arguments, but provided function takes {}".format( - spec - ) - ) - return spec.args - - def _create_lambda(f): """ Create `o.a.s.sql.expressions.LambdaFunction` corresponding @@ -3008,10 +2941,7 @@ def _create_lambda(f): - (Column, Column) -> Column: ... - (Column, Column, Column) -> Column: ... 
""" - if sys.version_info >= (3, 3): - parameters = _get_lambda_parameters(f) - else: - parameters = _get_lambda_parameters_legacy(f) + parameters = _get_lambda_parameters(f) sc = SparkContext._active_spark_context expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions @@ -3481,9 +3411,9 @@ def udf(f=None, returnType=StringType()): evalType=PythonEvalType.SQL_BATCHED_UDF) -blacklist = ['map', 'since', 'ignore_unicode_prefix'] +ignored_fns = ['map', 'since'] __all__ = [k for k, v in globals().items() - if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] + if not k.startswith('_') and k[0].islower() and callable(v) and k not in ignored_fns] __all__ += ["PandasUDFType"] __all__.sort() @@ -3500,7 +3430,7 @@ def _test(): sc = spark.sparkContext globs['sc'] = sc globs['spark'] = spark - globs['df'] = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + globs['df'] = spark.createDataFrame([Row(age=2, name='Alice'), Row(age=5, name='Bob')]) (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ac826bc64ad7e..83e2baa8f0002 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -18,7 +18,6 @@ import sys from pyspark import since -from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin @@ -60,7 +59,6 @@ def __init__(self, jgd, df): self._df = df self.sql_ctx = df.sql_ctx - @ignore_unicode_prefix @since(1.3) def agg(self, *exprs): """Compute aggregates and returns the result as a :class:`DataFrame`. @@ -91,18 +89,18 @@ def agg(self, *exprs): >>> gdf = df.groupBy(df.name) >>> sorted(gdf.agg({"*": "count"}).collect()) - [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] + [Row(name='Alice', count(1)=1), Row(name='Bob', count(1)=1)] >>> from pyspark.sql import functions as F >>> sorted(gdf.agg(F.min(df.age)).collect()) - [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + [Row(name='Alice', min(age)=2), Row(name='Bob', min(age)=5)] >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def min_udf(v): ... 
return v.min() >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP - [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)] + [Row(name='Alice', min_udf(age)=2), Row(name='Bob', min_udf(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index e6d8e9f24a557..3842bc2357c6c 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -16,11 +16,6 @@ # import sys import warnings -if sys.version >= '3': - basestring = unicode = str - xrange = range -else: - from itertools import izip as zip from collections import Counter from pyspark import since @@ -29,7 +24,6 @@ from pyspark.sql.types import IntegralType from pyspark.sql.types import * from pyspark.traceback_utils import SCCallSiteSync -from pyspark.util import _exception_message class PandasConversionMixin(object): @@ -84,7 +78,7 @@ def toPandas(self): "failed by the reason below:\n %s\n" "Attempting non-optimization as " "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to " - "true." % _exception_message(e)) + "true." % str(e)) warnings.warn(msg) use_arrow = False else: @@ -93,7 +87,7 @@ def toPandas(self): "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has " "reached the error below and will not continue because automatic fallback " "with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to " - "false.\n %s" % _exception_message(e)) + "false.\n %s" % str(e)) warnings.warn(msg) raise @@ -130,7 +124,7 @@ def toPandas(self): "reached the error below and can not continue. Note that " "'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an " "effect on failures in the middle of " - "computation.\n %s" % _exception_message(e)) + "computation.\n %s" % str(e)) warnings.warn(msg) raise @@ -268,7 +262,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr # If no schema supplied by user then get the names of columns only if schema is None: - schema = [str(x) if not isinstance(x, basestring) else + schema = [str(x) if not isinstance(x, str) else (x.encode('utf-8') if not isinstance(x, str) else x) for x in data.columns] @@ -276,8 +270,6 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: - from pyspark.util import _exception_message - if self._wrapped._conf.arrowPySparkFallbackEnabled(): msg = ( "createDataFrame attempted Arrow optimization because " @@ -285,7 +277,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr "failed by the reason below:\n %s\n" "Attempting non-optimization as " "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to " - "true." % _exception_message(e)) + "true." 
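# Editorial note (hedged): the Arrow conversion paths above follow a common "optimized
# path with opt-in fallback" shape; with Python 2 support gone the exception text is
# simply str(e). A schematic, Spark-free sketch of the pattern (names illustrative):
import warnings

def convert(data, fast_path, slow_path, fallback_enabled=True):
    try:
        return fast_path(data)
    except Exception as e:
        if not fallback_enabled:
            warnings.warn("optimization failed and fallback is disabled:\n %s" % str(e))
            raise
        warnings.warn("optimization failed, attempting non-optimized path:\n %s" % str(e))
        return slow_path(data)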
% str(e)) warnings.warn(msg) else: msg = ( @@ -293,7 +285,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has " "reached the error below and will not continue because automatic " "fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' " - "has been set to false.\n %s" % _exception_message(e)) + "has been set to false.\n %s" % str(e)) warnings.warn(msg) raise data = self._convert_from_pandas(data, schema, timezone) @@ -358,7 +350,7 @@ def _get_numpy_record_dtype(self, rec): col_names = cur_dtypes.names record_type_list = [] has_rec_fix = False - for i in xrange(len(cur_dtypes)): + for i in range(len(cur_dtypes)): curr_type = cur_dtypes[i] # If type is a datetime64 timestamp, convert to microseconds # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs, @@ -413,7 +405,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): # Slice the DataFrame to be batched step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up - pdf_slices = (pdf.iloc[start:start + step] for start in xrange(0, len(pdf), step)) + pdf_slices = (pdf.iloc[start:start + step] for start in range(0, len(pdf), step)) # Create list of Arrow (columns, type) for serializer dump_stream arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)] diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 094dc357b6822..ba4dec82d4eb4 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -18,6 +18,7 @@ import functools import sys import warnings +from inspect import getfullargspec from pyspark import since from pyspark.rdd import PythonEvalType @@ -25,7 +26,6 @@ from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version from pyspark.sql.types import DataType from pyspark.sql.udf import _create_udf -from pyspark.util import _get_argspec class PandasUDFType(object): @@ -371,30 +371,29 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: def _create_pandas_udf(f, returnType, evalType): - argspec = _get_argspec(f) + argspec = getfullargspec(f) # pandas UDF by type hints. - if sys.version_info >= (3, 6): - from inspect import signature - - if evalType in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: - warnings.warn( - "In Python 3.6+ and Spark 3.0+, it is preferred to specify type hints for " - "pandas UDF instead of specifying pandas UDF type which will be deprecated " - "in the future releases. See SPARK-28264 for more details.", UserWarning) - elif evalType in [PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, - PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]: - # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered - # at `apply` instead. - # In case of 'SQL_MAP_PANDAS_ITER_UDF' and 'SQL_COGROUPED_MAP_PANDAS_UDF', the - # evaluation type will always be set. 
- pass - elif len(argspec.annotations) > 0: - evalType = infer_eval_type(signature(f)) - assert evalType is not None + from inspect import signature + + if evalType in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: + warnings.warn( + "In Python 3.6+ and Spark 3.0+, it is preferred to specify type hints for " + "pandas UDF instead of specifying pandas UDF type which will be deprecated " + "in the future releases. See SPARK-28264 for more details.", UserWarning) + elif evalType in [PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]: + # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered + # at `apply` instead. + # In case of 'SQL_MAP_PANDAS_ITER_UDF' and 'SQL_COGROUPED_MAP_PANDAS_UDF', the + # evaluation type will always be set. + pass + elif len(argspec.annotations) > 0: + evalType = infer_eval_type(signature(f)) + assert evalType is not None if evalType is None: # Set default is scalar UDF. diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 42562e1fb9c46..4b91c6a0f8730 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -19,13 +19,6 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. """ -import sys -if sys.version < '3': - from itertools import izip as zip -else: - basestring = unicode = str - xrange = range - from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer @@ -67,7 +60,7 @@ def load_stream(self, stream): raise RuntimeError("An error occurred while calling " "ArrowCollectSerializer.load_stream: {}".format(error_msg)) batch_order = [] - for i in xrange(num): + for i in range(num): index = read_int(stream) batch_order.append(index) yield batch_order @@ -180,7 +173,7 @@ def create_array(s, t): if len(s) == 0 and len(s.columns) == 0: arrs_names = [(pa.array([], type=field.type), field.name) for field in t] # Assign result columns by schema name if user labeled with strings - elif self._assign_cols_by_name and any(isinstance(name, basestring) + elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns): arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t] @@ -194,7 +187,7 @@ def create_array(s, t): else: arrs.append(create_array(s, t)) - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) def dump_stream(self, iterator, stream): """ diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py index b0323ba1697df..e696f677cd154 100644 --- a/python/pyspark/sql/pandas/typehints.py +++ b/python/pyspark/sql/pandas/typehints.py @@ -98,8 +98,8 @@ def infer_eval_type(sig): a, parameter_check_func=lambda ua: ua == pd.Series or ua == pd.DataFrame) for a in parameters_sig) and ( - # It's tricky to whitelist which types pd.Series constructor can take. - # Simply blacklist common types used here for now (which becomes object + # It's tricky to include only types which pd.Series constructor can take. + # Simply exclude common types used here for now (which becomes object # types Spark can't recognize). 
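# Editorial note (hedged): the dedented block above always prefers type hints now that
# Python 3.6+ is assumed. The hinted pandas UDF style it nudges users toward looks like
# this (requires pandas and pyarrow, plus an active SparkSession named `spark`):
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def plus_one(s: pd.Series) -> pd.Series:
    return s + 1

spark.range(3).select(plus_one("id")).show()   # 1, 2, 3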
return_annotation != pd.Series and return_annotation != pd.DataFrame and diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 336345e383729..a83aece2e485d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -15,15 +15,9 @@ # limitations under the License. # -import sys - -if sys.version >= '3': - basestring = unicode = str - from py4j.java_gateway import JavaClass from pyspark import RDD, since -from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * from pyspark.sql import utils @@ -94,7 +88,7 @@ def schema(self, schema): if isinstance(schema, StructType): jschema = spark._jsparkSession.parseDataType(schema.json()) self._jreader = self._jreader.schema(jschema) - elif isinstance(schema, basestring): + elif isinstance(schema, str): self._jreader = self._jreader.schema(schema) else: raise TypeError("schema should be StructType or string") @@ -174,7 +168,7 @@ def load(self, path=None, format=None, schema=None, **options): if schema is not None: self.schema(schema) self.options(**options) - if isinstance(path, basestring): + if isinstance(path, str): return self._df(self._jreader.load(path)) elif path is not None: if type(path) != list: @@ -294,16 +288,16 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding, locale=locale, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) - if isinstance(path, basestring): + if isinstance(path, str): path = [path] if type(path) == list: return self._df(self._jreader.json(self._spark._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): def func(iterator): for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): + if not isinstance(x, str): + x = str(x) + if isinstance(x, str): x = x.encode("utf-8") yield x keyed = path.mapPartitions(func) @@ -352,7 +346,6 @@ def parquet(self, *paths, **options): recursiveFileLookup=recursiveFileLookup) return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths))) - @ignore_unicode_prefix @since(1.6) def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None, recursiveFileLookup=None): @@ -376,15 +369,15 @@ def text(self, paths, wholetext=False, lineSep=None, pathGlobFilter=None, >>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() - [Row(value=u'hello'), Row(value=u'this')] + [Row(value='hello'), Row(value='this')] >>> df = spark.read.text('python/test_support/sql/text-test.txt', wholetext=True) >>> df.collect() - [Row(value=u'hello\\nthis')] + [Row(value='hello\\nthis')] """ self._set_opts( wholetext=wholetext, lineSep=lineSep, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) - if isinstance(paths, basestring): + if isinstance(paths, str): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @@ -529,16 +522,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) - if isinstance(path, basestring): + if isinstance(path, str): path = [path] if type(path) == 
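# Editorial note (hedged): DataFrameReader.schema() above accepts either a StructType
# or a DDL-formatted string (now validated against plain `str`). Two equivalent reads,
# assuming an active SparkSession named `spark` and a hypothetical people.csv file:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

struct_schema = StructType([
    StructField("name", StringType()),
    StructField("age", IntegerType()),
])
df1 = spark.read.schema(struct_schema).csv("people.csv")
df2 = spark.read.schema("name STRING, age INT").csv("people.csv")   # DDL string form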
list: return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): def func(iterator): for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): + if not isinstance(x, str): + x = str(x) + if isinstance(x, str): x = x.encode("utf-8") yield x keyed = path.mapPartitions(func) @@ -574,7 +567,7 @@ def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=N """ self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) - if isinstance(path, basestring): + if isinstance(path, str): path = [path] return self._df(self._jreader.orc(_to_seq(self._spark._sc, path))) @@ -763,7 +756,7 @@ def bucketBy(self, numBuckets, col, *cols): col, cols = col[0], col[1:] - if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)): + if not all(isinstance(c, str) for c in cols) or not(isinstance(col, str)): raise TypeError("all names should be `str`") self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols)) @@ -788,7 +781,7 @@ def sortBy(self, col, *cols): col, cols = col[0], col[1:] - if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)): + if not all(isinstance(c, str) for c in cols) or not(isinstance(col, str)): raise TypeError("all names should be `str`") self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols)) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 61891c478dbe4..a5d102712d5e4 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -15,22 +15,13 @@ # limitations under the License. # -# To disallow implicit relative import. Remove this once we drop Python 2. -from __future__ import absolute_import -from __future__ import print_function import sys import warnings from functools import reduce from threading import RLock -if sys.version >= '3': - basestring = unicode = str - xrange = range -else: - from itertools import imap as map - from pyspark import since -from pyspark.rdd import RDD, ignore_unicode_prefix +from pyspark.rdd import RDD from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.pandas.conversion import SparkConversionMixin @@ -56,7 +47,7 @@ def toDF(self, schema=None, sampleRatio=None): :return: a DataFrame >>> rdd.toDF().collect() - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] """ return sparkSession.createDataFrame(self, schema, sampleRatio) @@ -197,7 +188,6 @@ def getOrCreate(self): _instantiatedSession = None _activeSession = None - @ignore_unicode_prefix def __init__(self, sparkContext, jsparkSession=None): """Creates a new SparkSession. 
@@ -213,7 +203,7 @@ def __init__(self, sparkContext, jsparkSession=None): [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \ dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + [(1, 'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ from pyspark.sql.context import SQLContext self._sc = sparkContext @@ -492,7 +482,6 @@ def _create_shell_session(): return SparkSession.builder.getOrCreate() @since(2.0) - @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): """ Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. @@ -530,34 +519,29 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr .. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental. - .. note:: When Arrow optimization is enabled, strings inside Pandas DataFrame in Python - 2 are converted into bytes as they are bytes in Python 2 whereas regular strings are - left as strings. When using strings in Python 2, use unicode `u""` as Python standard - practice. - >>> l = [('Alice', 1)] >>> spark.createDataFrame(l).collect() - [Row(_1=u'Alice', _2=1)] + [Row(_1='Alice', _2=1)] >>> spark.createDataFrame(l, ['name', 'age']).collect() - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> d = [{'name': 'Alice', 'age': 1}] >>> spark.createDataFrame(d).collect() - [Row(age=1, name=u'Alice')] + [Row(age=1, name='Alice')] >>> rdd = sc.parallelize(l) >>> spark.createDataFrame(rdd).collect() - [Row(_1=u'Alice', _2=1)] + [Row(_1='Alice', _2=1)] >>> df = spark.createDataFrame(rdd, ['name', 'age']) >>> df.collect() - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> from pyspark.sql import Row >>> Person = Row('name', 'age') >>> person = rdd.map(lambda r: Person(*r)) >>> df2 = spark.createDataFrame(person) >>> df2.collect() - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> from pyspark.sql.types import * >>> schema = StructType([ @@ -565,15 +549,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr ... StructField("age", IntegerType(), True)]) >>> df3 = spark.createDataFrame(rdd, schema) >>> df3.collect() - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> spark.createDataFrame(df.toPandas()).collect() # doctest: +SKIP - [Row(name=u'Alice', age=1)] + [Row(name='Alice', age=1)] >>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP [Row(0=1, 1=2)] >>> spark.createDataFrame(rdd, "a: string, b: int").collect() - [Row(a=u'Alice', b=1)] + [Row(a='Alice', b=1)] >>> rdd = rdd.map(lambda row: row[1]) >>> spark.createDataFrame(rdd, "int").collect() [Row(value=1)] @@ -587,7 +571,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") - if isinstance(schema, basestring): + if isinstance(schema, str): schema = _parse_datatype_string(schema) elif isinstance(schema, (list, tuple)): # Must re-encode any unicode strings to be consistent with StructField names @@ -634,7 +618,6 @@ def prepare(obj): df._schema = schema return df - @ignore_unicode_prefix @since(2.0) def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. 
@@ -644,7 +627,7 @@ def sql(self, sqlQuery): >>> df.createOrReplaceTempView("table1") >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + [Row(f1=1, f2='row1'), Row(f1=2, f2='row2'), Row(f1=3, f2='row3')] """ return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 2450a4c93c460..5c528c1d54df7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -18,13 +18,9 @@ import sys import json -if sys.version >= '3': - basestring = str - from py4j.java_gateway import java_import from pyspark import since, keyword_only -from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * @@ -204,7 +200,6 @@ def __init__(self, jsqm): self._jsqm = jsqm @property - @ignore_unicode_prefix @since(2.0) def active(self): """Returns a list of active queries associated with this SQLContext @@ -213,12 +208,11 @@ def active(self): >>> sqm = spark.streams >>> # get the list of active streaming queries >>> [q.name for q in sqm.active] - [u'this_query'] + ['this_query'] >>> sq.stop() """ return [StreamingQuery(jsq) for jsq in self._jsqm.active()] - @ignore_unicode_prefix @since(2.0) def get(self, id): """Returns an active query from this SQLContext or throws exception if an active query @@ -226,7 +220,7 @@ def get(self, id): >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() >>> sq.name - u'this_query' + 'this_query' >>> sq = spark.streams.get(sq.id) >>> sq.isActive True @@ -328,7 +322,7 @@ def schema(self, schema): if isinstance(schema, StructType): jschema = spark._jsparkSession.parseDataType(schema.json()) self._jreader = self._jreader.schema(jschema) - elif isinstance(schema, basestring): + elif isinstance(schema, str): self._jreader = self._jreader.schema(schema) else: raise TypeError("schema should be StructType or string") @@ -527,7 +521,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) - if isinstance(path, basestring): + if isinstance(path, str): return self._df(self._jreader.json(path)) else: raise TypeError("path can be only a single string") @@ -555,7 +549,7 @@ def orc(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLookup=N """ self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) - if isinstance(path, basestring): + if isinstance(path, str): return self._df(self._jreader.orc(path)) else: raise TypeError("path can be only a single string") @@ -585,12 +579,11 @@ def parquet(self, path, mergeSchema=None, pathGlobFilter=None, recursiveFileLook """ self._set_opts(mergeSchema=mergeSchema, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) - if isinstance(path, basestring): + if isinstance(path, str): return self._df(self._jreader.parquet(path)) else: raise TypeError("path can be only a single string") - @ignore_unicode_prefix @since(2.0) def text(self, path, wholetext=False, lineSep=None, pathGlobFilter=None, recursiveFileLookup=None): @@ -623,7 +616,7 @@ def text(self, path, wholetext=False, lineSep=None, 
pathGlobFilter=None, self._set_opts( wholetext=wholetext, lineSep=lineSep, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) - if isinstance(path, basestring): + if isinstance(path, str): return self._df(self._jreader.text(path)) else: raise TypeError("path can be only a single string") @@ -762,7 +755,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep, pathGlobFilter=pathGlobFilter, recursiveFileLookup=recursiveFileLookup) - if isinstance(path, basestring): + if isinstance(path, str): return self._df(self._jreader.csv(path)) else: raise TypeError("path can be only a single string") @@ -1153,7 +1146,6 @@ def foreachBatch(self, func): ensure_callback_server_started(gw) return self - @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, **options): @@ -1186,14 +1178,14 @@ def start(self, path=None, format=None, outputMode=None, partitionBy=None, query >>> sq.isActive True >>> sq.name - u'this_query' + 'this_query' >>> sq.stop() >>> sq.isActive False >>> sq = sdf.writeStream.trigger(processingTime='5 seconds').start( ... queryName='that_query', outputMode="append", format='memory') >>> sq.name - u'that_query' + 'that_query' >>> sq.isActive True >>> sq.stop() diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 913b43b6ddb5a..148df9b7d45b8 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -29,7 +29,6 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message from pyspark.testing.utils import QuietTest -from pyspark.util import _exception_message if have_pandas: import pandas as pd @@ -127,7 +126,7 @@ def test_toPandas_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempting non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in str(user_warns[-1])) assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): @@ -355,7 +354,7 @@ def test_createDataFrame_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempting non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in str(user_warns[-1])) self.assertEqual(df.collect(), [Row(a={u'a': 1})]) def test_createDataFrame_fallback_disabled(self): @@ -448,6 +447,13 @@ def test_createDataFrame_with_float_index(self): self.spark.createDataFrame( pd.DataFrame({'a': [1, 2, 3]}, index=[2., 3., 4.])).distinct().count(), 3) + def test_no_partition_toPandas(self): + # SPARK-32301: toPandas should work from a Spark DataFrame with no partitions + # Forward-ported from SPARK-32300. 
+ pdf = self.spark.sparkContext.emptyRDD().toDF("col1 int").toPandas() + self.assertEqual(len(pdf), 0) + self.assertEqual(list(pdf.columns), ["col1"]) + @unittest.skipIf( not have_pandas or not have_pyarrow, diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 58bf896a10c2a..e0b8bf45a2c70 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -16,8 +16,6 @@ # limitations under the License. # -import sys - from pyspark.sql import Column, Row from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException @@ -109,12 +107,8 @@ def test_access_column(self): self.assertRaises(TypeError, lambda: df[{}]) def test_column_name_with_non_ascii(self): - if sys.version >= '3': - columnName = "数量" - self.assertTrue(isinstance(columnName, str)) - else: - columnName = unicode("数量", "utf-8") - self.assertTrue(isinstance(columnName, unicode)) + columnName = "数量" + self.assertTrue(isinstance(columnName, str)) schema = StructType([StructField(columnName, LongType(), True)]) df = self.spark.createDataFrame([(1,)], schema) self.assertEqual(schema, df.schema) diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index 3b1b638ed4aa6..ff953ba4b4b76 100644 --- a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -19,11 +19,7 @@ import sys import tempfile import unittest -try: - from importlib import reload # Python 3.4+ only. -except ImportError: - # Otherwise, we will stick to Python 2's built-in reload. - pass +from importlib import reload import py4j diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 52ae74df5d4f2..7dcc19f3ba45d 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -167,10 +167,6 @@ def test_string_functions(self): TypeError, "must be the same type", lambda: df.select(col('name').substr(0, lit(1)))) - if sys.version_info.major == 2: - self.assertRaises( - TypeError, - lambda: df.select(col('name').substr(long(0), long(1)))) for name in _string_functions.keys(): self.assertEqual( diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py index c1cb30c3caa91..24a73918d8be4 100644 --- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py @@ -32,11 +32,6 @@ import pyarrow as pa -# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names -# From kwargs w/ Python 2, so need to set check_column_type=False and avoid this check -_check_column_type = sys.version >= '3' - - @unittest.skipIf( not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message) @@ -109,7 +104,7 @@ def merge_pandas(l, r): 'v2': [90, 100, 110] }) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) def test_empty_group_by(self): left = self.data1 @@ -130,7 +125,7 @@ def merge_pandas(l, r): .merge(left, right, on=['id', 'k']) \ .sort_values(by=['id', 'k']) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self): df = self.spark.range(0, 10).toDF('v1') @@ -173,7 +168,7 @@ def left_assign_key(key, l, _): expected = self.data1.toPandas() expected = 
expected.assign(key=expected.id % 2 == 0) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) def test_wrong_return_type(self): # Test that we get a sensible exception invalid values passed to apply @@ -224,7 +219,7 @@ def right_assign_key(key, l, r): expected = left.toPandas() if isLeft else right.toPandas() expected = expected.assign(key=expected.id) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) @staticmethod def _test_merge(left, right, output_schema='id long, k int, v int, v2 int'): @@ -246,7 +241,7 @@ def merge_pandas(l, r): .merge(left, right, on=['id', 'k']) \ .sort_values(by=['id', 'k']) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py index cc6167e619285..00cc9b3a64c73 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py @@ -38,11 +38,6 @@ import pyarrow as pa -# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names -# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check -_check_column_type = sys.version >= '3' - - @unittest.skipIf( not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message) @@ -139,9 +134,9 @@ def test_supported_types(self): result3 = df.groupby('id').apply(udf3).sort('id').toPandas() expected3 = expected1 - assert_frame_equal(expected1, result1, check_column_type=_check_column_type) - assert_frame_equal(expected2, result2, check_column_type=_check_column_type) - assert_frame_equal(expected3, result3, check_column_type=_check_column_type) + assert_frame_equal(expected1, result1) + assert_frame_equal(expected2, result2) + assert_frame_equal(expected3, result3) def test_array_type_correct(self): df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") @@ -159,7 +154,7 @@ def test_array_type_correct(self): result = df.groupby('id').apply(udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) def test_register_grouped_map_udf(self): foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) @@ -181,7 +176,7 @@ def foo(pdf): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) def test_coerce(self): df = self.data @@ -195,7 +190,7 @@ def test_coerce(self): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) expected = expected.assign(v=expected.v.astype('float64')) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) def test_complex_groupby(self): df = self.data @@ -213,7 +208,7 @@ def normalize(pdf): expected = pdf.groupby(pdf['id'] % 2 == 0, as_index=False).apply(normalize.func) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - 
assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) def test_empty_groupby(self): df = self.data @@ -231,7 +226,7 @@ def normalize(pdf): expected = normalize.func(pdf) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) def test_datatype_string(self): df = self.data @@ -244,7 +239,7 @@ def test_datatype_string(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) def test_wrong_return_type(self): with QuietTest(self.sc): @@ -301,7 +296,7 @@ def test_timestamp_dst(self): df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP) result = df.groupby('time').apply(foo_udf).sort('time') - assert_frame_equal(df.toPandas(), result.toPandas(), check_column_type=_check_column_type) + assert_frame_equal(df.toPandas(), result.toPandas()) def test_udf_with_key(self): import numpy as np @@ -355,26 +350,26 @@ def foo3(key, pdf): expected1 = pdf.groupby('id', as_index=False)\ .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ .sort_values(['id', 'v']).reset_index(drop=True) - assert_frame_equal(expected1, result1, check_column_type=_check_column_type) + assert_frame_equal(expected1, result1) # Test groupby expression result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() expected2 = pdf.groupby(pdf.id % 2, as_index=False)\ .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ .sort_values(['id', 'v']).reset_index(drop=True) - assert_frame_equal(expected2, result2, check_column_type=_check_column_type) + assert_frame_equal(expected2, result2) # Test complex groupby result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() expected3 = pdf.groupby([pdf.id, pdf.v % 2], as_index=False)\ .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ .sort_values(['id', 'v']).reset_index(drop=True) - assert_frame_equal(expected3, result3, check_column_type=_check_column_type) + assert_frame_equal(expected3, result3) # Test empty groupby result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() expected4 = udf3.func((), pdf) - assert_frame_equal(expected4, result4, check_column_type=_check_column_type) + assert_frame_equal(expected4, result4) def test_column_order(self): @@ -407,7 +402,7 @@ def change_col_order(pdf): .select('id', 'u', 'v').toPandas() pd_result = grouped_pdf.apply(change_col_order) expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) # Function returns a pdf with positional columns, indexed by range def range_col_order(pdf): @@ -426,7 +421,7 @@ def range_col_order(pdf): pd_result = grouped_pdf.apply(range_col_order) rename_pdf(pd_result, ['id', 'u', 'v']) expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) # Function returns a pdf with columns indexed with integers def int_index(pdf): @@ -444,7 +439,7 @@ def int_index(pdf): pd_result = 
grouped_pdf.apply(int_index) rename_pdf(pd_result, ['id', 'u', 'v']) expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) - assert_frame_equal(expected, result, check_column_type=_check_column_type) + assert_frame_equal(expected, result) @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) def column_name_typo(pdf): diff --git a/python/pyspark/sql/tests/test_pandas_map.py b/python/pyspark/sql/tests/test_pandas_map.py index f1956a2523e48..02ae6a86f9ab3 100644 --- a/python/pyspark/sql/tests/test_pandas_map.py +++ b/python/pyspark/sql/tests/test_pandas_map.py @@ -19,9 +19,6 @@ import time import unittest -if sys.version >= '3': - unicode = str - from pyspark.sql.functions import pandas_udf, PandasUDFType from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 2d38efd39f902..448e409b0c377 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import datetime import os import random import shutil @@ -22,10 +21,6 @@ import tempfile import time import unittest - -if sys.version >= '3': - unicode = str - from datetime import date, datetime from decimal import Decimal @@ -319,7 +314,7 @@ def test_vectorized_udf_struct_type(self): StructField('str', StringType())]) def scalar_func(id): - return pd.DataFrame({'id': id, 'str': id.apply(unicode)}) + return pd.DataFrame({'id': id, 'str': id.apply(str)}) def iter_func(it): for id in it: @@ -486,14 +481,14 @@ def test_vectorized_udf_chained_struct_type(self): @pandas_udf(return_type) def scalar_f(id): - return pd.DataFrame({'id': id, 'str': id.apply(unicode)}) + return pd.DataFrame({'id': id, 'str': id.apply(str)}) scalar_g = pandas_udf(lambda x: x, return_type) @pandas_udf(return_type, PandasUDFType.SCALAR_ITER) def iter_f(it): for id in it: - yield pd.DataFrame({'id': id, 'str': id.apply(unicode)}) + yield pd.DataFrame({'id': id, 'str': id.apply(str)}) iter_g = pandas_udf(lambda x: x, return_type, PandasUDFType.SCALAR_ITER) @@ -915,21 +910,12 @@ def to_category_func(x): # Check result of column 'B' must be equal to column 'A' in type and values pd.testing.assert_series_equal(result_spark["A"], result_spark["B"], check_names=False) - @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") def test_type_annotation(self): # Regression test to check if type hints can be used. See SPARK-23569. - # Note that it throws an error during compilation in lower Python versions if 'exec' - # is not used. Also, note that we explicitly use another dictionary to avoid modifications - # in the current 'locals()'. - # - # Hyukjin: I think it's an ugly way to test issues about syntax specific in - # higher versions of Python, which we shouldn't encourage. This was the last resort - # I could come up with at that time. 
- _locals = {} - exec( - "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col", - _locals) - df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) + def noop(col: pd.Series) -> pd.Series: + return col + + df = self.spark.range(1).select(pandas_udf(f=noop, returnType='bigint')('id')) self.assertEqual(df.first()[0], 0) def test_mixed_udf(self): diff --git a/python/pyspark/sql/tests/test_pandas_udf_typehints.py b/python/pyspark/sql/tests/test_pandas_udf_typehints.py index 2582080056864..618164fa8496f 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_typehints.py +++ b/python/pyspark/sql/tests/test_pandas_udf_typehints.py @@ -14,9 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import sys import unittest import inspect +from typing import Union, Iterator, Tuple from pyspark.sql.functions import mean, lit from pyspark.testing.sqlutils import ReusedSQLTestCase, \ @@ -24,209 +24,162 @@ pyarrow_requirement_message from pyspark.sql.pandas.typehints import infer_eval_type from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType +from pyspark.sql import Row if have_pandas: import pandas as pd + import numpy as np from pandas.util.testing import assert_frame_equal -python_requirement_message = "pandas UDF with type hints are supported with Python 3.6+." - @unittest.skipIf( - not have_pandas or not have_pyarrow or sys.version_info[:2] < (3, 6), - pandas_requirement_message or pyarrow_requirement_message or python_requirement_message) + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) class PandasUDFTypeHintsTests(ReusedSQLTestCase): - # Note that, we should remove `exec` once we drop Python 2 in this class. 
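The rewrites in this test file work because annotations are ordinary syntax on Python 3, so the exec/self.local indirection can be deleted. A small sketch of the inference these tests exercise (assumes pandas and pyarrow are installed, as the surrounding skipIf guards already require):

import inspect
import pandas as pd
from pyspark.sql.pandas.functions import PandasUDFType
from pyspark.sql.pandas.typehints import infer_eval_type

def plus_one(v: pd.Series) -> pd.Series:
    return v + 1

# A Series -> Series signature maps to a scalar pandas UDF, which is what the
# test_type_annotation_scalar cases below assert via assertEqual.
assert infer_eval_type(inspect.signature(plus_one)) == PandasUDFType.SCALAR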
- - def setUp(self): - self.local = {'pd': pd} - def test_type_annotation_scalar(self): - exec( - "def func(col: pd.Series) -> pd.Series: pass", - self.local) + def func(col: pd.Series) -> pd.Series: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) - exec( - "def func(col: pd.DataFrame, col1: pd.Series) -> pd.DataFrame: pass", - self.local) + def func(col: pd.DataFrame, col1: pd.Series) -> pd.DataFrame: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) - exec( - "def func(col: pd.DataFrame, *args: pd.Series) -> pd.Series: pass", - self.local) + def func(col: pd.DataFrame, *args: pd.Series) -> pd.Series: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) - exec( - "def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> pd.Series:\n" - " pass", - self.local) + def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> pd.Series: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) - exec( - "def func(col: pd.Series, *, col2: pd.DataFrame) -> pd.DataFrame:\n" - " pass", - self.local) + def func(col: pd.Series, *, col2: pd.DataFrame) -> pd.DataFrame: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) - exec( - "from typing import Union\n" - "def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> pd.Series:\n" - " pass", - self.local) + def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> pd.Series: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR) def test_type_annotation_scalar_iter(self): - exec( - "from typing import Iterator\n" - "def func(iter: Iterator[pd.Series]) -> Iterator[pd.Series]: pass", - self.local) + def func(iter: Iterator[pd.Series]) -> Iterator[pd.Series]: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR_ITER) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR_ITER) - exec( - "from typing import Iterator, Tuple\n" - "def func(iter: Iterator[Tuple[pd.DataFrame, pd.Series]]) -> Iterator[pd.DataFrame]:\n" - " pass", - self.local) + def func(iter: Iterator[Tuple[pd.DataFrame, pd.Series]]) -> Iterator[pd.DataFrame]: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR_ITER) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR_ITER) - exec( - "from typing import Iterator, Tuple\n" - "def func(iter: Iterator[Tuple[pd.DataFrame, ...]]) -> Iterator[pd.Series]: pass", - self.local) + def func(iter: Iterator[Tuple[pd.DataFrame, ...]]) -> Iterator[pd.Series]: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR_ITER) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR_ITER) - exec( - "from typing import Iterator, Tuple, Union\n" - "def func(iter: Iterator[Tuple[Union[pd.DataFrame, pd.Series], ...]])" - " -> 
Iterator[pd.Series]: pass", - self.local) + def func( + iter: Iterator[Tuple[Union[pd.DataFrame, pd.Series], ...]] + ) -> Iterator[pd.Series]: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR_ITER) + infer_eval_type(inspect.signature(func)), PandasUDFType.SCALAR_ITER) def test_type_annotation_group_agg(self): - exec( - "def func(col: pd.Series) -> str: pass", - self.local) + + def func(col: pd.Series) -> str: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG) + infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) - exec( - "def func(col: pd.DataFrame, col1: pd.Series) -> int: pass", - self.local) + def func(col: pd.DataFrame, col1: pd.Series) -> int: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG) + infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) - exec( - "from pyspark.sql import Row\n" - "def func(col: pd.DataFrame, *args: pd.Series) -> Row: pass", - self.local) + def func(col: pd.DataFrame, *args: pd.Series) -> Row: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG) + infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) - exec( - "def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> str:\n" - " pass", - self.local) + def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> str: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG) + infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) - exec( - "def func(col: pd.Series, *, col2: pd.DataFrame) -> float:\n" - " pass", - self.local) + def func(col: pd.Series, *, col2: pd.DataFrame) -> float: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG) + infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) - exec( - "from typing import Union\n" - "def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> float:\n" - " pass", - self.local) + def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> float: + pass self.assertEqual( - infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG) + infer_eval_type(inspect.signature(func)), PandasUDFType.GROUPED_AGG) def test_type_annotation_negative(self): - exec( - "def func(col: str) -> pd.Series: pass", - self.local) + + def func(col: str) -> pd.Series: + pass self.assertRaisesRegex( NotImplementedError, "Unsupported signature.*str", - infer_eval_type, inspect.signature(self.local['func'])) + infer_eval_type, inspect.signature(func)) - exec( - "def func(col: pd.DataFrame, col1: int) -> pd.DataFrame: pass", - self.local) + def func(col: pd.DataFrame, col1: int) -> pd.DataFrame: + pass self.assertRaisesRegex( NotImplementedError, "Unsupported signature.*int", - infer_eval_type, inspect.signature(self.local['func'])) + infer_eval_type, inspect.signature(func)) - exec( - "from typing import Union\n" - "def func(col: Union[pd.DataFrame, str], col1: int) -> pd.DataFrame: pass", - self.local) + def func(col: Union[pd.DataFrame, str], col1: int) -> pd.DataFrame: + pass self.assertRaisesRegex( NotImplementedError, "Unsupported signature.*str", - infer_eval_type, inspect.signature(self.local['func'])) + infer_eval_type, inspect.signature(func)) - exec( - "from typing import Tuple\n" - "def func(col: 
pd.Series) -> Tuple[pd.DataFrame]: pass", - self.local) + def func(col: pd.Series) -> Tuple[pd.DataFrame]: + pass self.assertRaisesRegex( NotImplementedError, "Unsupported signature.*Tuple", - infer_eval_type, inspect.signature(self.local['func'])) + infer_eval_type, inspect.signature(func)) - exec( - "def func(col, *args: pd.Series) -> pd.Series: pass", - self.local) + def func(col, *args: pd.Series) -> pd.Series: + pass self.assertRaisesRegex( ValueError, "should be specified.*Series", - infer_eval_type, inspect.signature(self.local['func'])) + infer_eval_type, inspect.signature(func)) - exec( - "def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame):\n" - " pass", - self.local) + def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame): + pass self.assertRaisesRegex( ValueError, "should be specified.*Series", - infer_eval_type, inspect.signature(self.local['func'])) + infer_eval_type, inspect.signature(func)) - exec( - "def func(col: pd.Series, *, col2) -> pd.DataFrame:\n" - " pass", - self.local) + def func(col: pd.Series, *, col2) -> pd.DataFrame: + pass self.assertRaisesRegex( ValueError, "should be specified.*Series", - infer_eval_type, inspect.signature(self.local['func'])) + infer_eval_type, inspect.signature(func)) def test_scalar_udf_type_hint(self): df = self.spark.range(10).selectExpr("id", "id as v") - exec( - "import typing\n" - "def plus_one(v: typing.Union[pd.Series, pd.DataFrame]) -> pd.Series:\n" - " return v + 1", - self.local) - - plus_one = pandas_udf("long")(self.local["plus_one"]) + def plus_one(v: Union[pd.Series, pd.DataFrame]) -> pd.Series: + return v + 1 + plus_one = pandas_udf("long")(plus_one) actual = df.select(plus_one(df.v).alias("plus_one")) expected = df.selectExpr("(v + 1) as plus_one") assert_frame_equal(expected.toPandas(), actual.toPandas()) @@ -234,14 +187,11 @@ def test_scalar_udf_type_hint(self): def test_scalar_iter_udf_type_hint(self): df = self.spark.range(10).selectExpr("id", "id as v") - exec( - "import typing\n" - "def plus_one(itr: typing.Iterator[pd.Series]) -> typing.Iterator[pd.Series]:\n" - " for s in itr:\n" - " yield s + 1", - self.local) + def plus_one(itr: Iterator[pd.Series]) -> Iterator[pd.Series]: + for s in itr: + yield s + 1 - plus_one = pandas_udf("long")(self.local["plus_one"]) + plus_one = pandas_udf("long")(plus_one) actual = df.select(plus_one(df.v).alias("plus_one")) expected = df.selectExpr("(v + 1) as plus_one") @@ -249,13 +199,11 @@ def test_scalar_iter_udf_type_hint(self): def test_group_agg_udf_type_hint(self): df = self.spark.range(10).selectExpr("id", "id as v") - exec( - "import numpy as np\n" - "def weighted_mean(v: pd.Series, w: pd.Series) -> float:\n" - " return np.average(v, weights=w)", - self.local) - weighted_mean = pandas_udf("double")(self.local["weighted_mean"]) + def weighted_mean(v: pd.Series, w: pd.Series) -> float: + return np.average(v, weights=w) + + weighted_mean = pandas_udf("double")(weighted_mean) actual = df.groupby('id').agg(weighted_mean(df.v, lit(1.0))).sort('id') expected = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') @@ -263,12 +211,9 @@ def test_group_agg_udf_type_hint(self): def test_ignore_type_hint_in_group_apply_in_pandas(self): df = self.spark.range(10) - exec( - "def pandas_plus_one(v: pd.DataFrame) -> pd.DataFrame:\n" - " return v + 1", - self.local) - pandas_plus_one = self.local["pandas_plus_one"] + def pandas_plus_one(v: pd.DataFrame) -> pd.DataFrame: + return v + 1 actual = df.groupby('id').applyInPandas(pandas_plus_one, 
schema=df.schema).sort('id') expected = df.selectExpr("id + 1 as id") @@ -276,12 +221,9 @@ def test_ignore_type_hint_in_group_apply_in_pandas(self): def test_ignore_type_hint_in_cogroup_apply_in_pandas(self): df = self.spark.range(10) - exec( - "def pandas_plus_one(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:\n" - " return left + 1", - self.local) - pandas_plus_one = self.local["pandas_plus_one"] + def pandas_plus_one(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: + return left + 1 actual = df.groupby('id').cogroup( self.spark.range(10).groupby("id") @@ -291,13 +233,9 @@ def test_ignore_type_hint_in_cogroup_apply_in_pandas(self): def test_ignore_type_hint_in_map_in_pandas(self): df = self.spark.range(10) - exec( - "from typing import Iterator\n" - "def pandas_plus_one(iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:\n" - " return map(lambda v: v + 1, iter)", - self.local) - pandas_plus_one = self.local["pandas_plus_one"] + def pandas_plus_one(iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: + return map(lambda v: v + 1, iter) actual = df.mapInPandas(pandas_plus_one, schema=df.schema) expected = df.selectExpr("id + 1 as id") diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 81402f52af3b3..051c8bde50ad9 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -56,7 +56,7 @@ def test_infer_schema_to_local(self): self.assertEqual(10, df3.count()) def test_apply_schema_to_dict_and_rows(self): - schema = StructType().add("b", StringType()).add("a", IntegerType()) + schema = StructType().add("a", IntegerType()).add("b", StringType()) input = [{"a": 1}, {"b": "coffee"}] rdd = self.sc.parallelize(input) for verify in [False, True]: @@ -72,7 +72,6 @@ def test_apply_schema_to_dict_and_rows(self): self.assertEqual(10, df4.count()) def test_create_dataframe_schema_mismatch(self): - input = [Row(a=1)] rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) df = self.spark.createDataFrame(rdd, schema) @@ -540,7 +539,6 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) - @unittest.skipIf(sys.version < "3", "only Python 3 infers bytes as binary type") def test_infer_binary_type(self): binaryrow = [Row(f1='a', f2=b"abcd")] df = self.sc.parallelize(binaryrow).toDF() @@ -665,10 +663,6 @@ def assertCollectSuccess(typecode, value): supported_string_types += ['u'] # test unicode assertCollectSuccess('u', u'a') - if sys.version_info[0] < 3: - supported_string_types += ['c'] - # test string - assertCollectSuccess('c', 'a') # supported float and double # @@ -721,11 +715,8 @@ def assertCollectSuccess(typecode, value): # # Keys in _array_type_mappings is a complete list of all supported types, # and types not in _array_type_mappings are considered unsupported. - # `array.typecodes` are not supported in python 2. - if sys.version_info[0] < 3: - all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd']) - else: - all_types = set(array.typecodes) + # PyPy seems not having array.typecodes. 
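# For background on the hard-coded set that follows: array.typecodes is not
# exposed by PyPy3, and relative to the old Python 2 list this set drops the
# removed 'c' (char) typecode and adds the long-long codes 'q'/'Q' that are
# available on Python 3.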
+ all_types = set(['b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'q', 'Q', 'f', 'd']) unsupported_types = all_types - set(supported_types) # test unsupported types for t in unsupported_types: @@ -766,10 +757,7 @@ def test_row_without_column_name(self): self.assertEqual(repr(row), "") # test __repr__ with unicode values - if sys.version_info.major >= 3: - self.assertEqual(repr(Row("数", "量")), "") - else: - self.assertEqual(repr(Row(u"数", u"量")), r"") + self.assertEqual(repr(Row("数", "量")), "") def test_empty_row(self): row = Row() @@ -887,7 +875,6 @@ def __init__(self, **kwargs): ({"s": "a", "f": 1.0}, schema), (Row(s="a", i=1), schema), (Row(s="a", i=None), schema), - (Row(s="a", i=1, f=1.0), schema), (["a", 1], schema), (["a", None], schema), (("a", 1), schema), @@ -972,18 +959,13 @@ def __init__(self, **kwargs): with self.assertRaises(exp, msg=msg): _make_type_verifier(data_type, nullable=False)(obj) - @unittest.skipIf(sys.version_info[:2] < (3, 6), "Create Row without sorting fields") def test_row_without_field_sorting(self): - sorting_enabled_tmp = Row._row_field_sorting_enabled - Row._row_field_sorting_enabled = False - r = Row(b=1, a=2) TestRow = Row("b", "a") expected = TestRow(1, 2) self.assertEqual(r, expected) self.assertEqual(repr(r), "Row(b=1, a=2)") - Row._row_field_sorting_enabled = sorting_enabled_tmp if __name__ == "__main__": diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 320a68dffe7a3..cc08482c735b1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -15,7 +15,6 @@ # limitations under the License. # -import os import sys import decimal import time @@ -26,11 +25,6 @@ import base64 from array import array import ctypes -import warnings - -if sys.version >= "3": - long = int - basestring = unicode = str from py4j.protocol import register_input_converter from py4j.java_gateway import JavaClass @@ -409,9 +403,7 @@ def __init__(self, name, dataType, nullable=True, metadata=None): """ assert isinstance(dataType, DataType),\ "dataType %s should be an instance of %s" % (dataType, DataType) - assert isinstance(name, basestring), "field name %s should be string" % (name) - if not isinstance(name, str): - name = name.encode('utf-8') + assert isinstance(name, str), "field name %s should be a string" % (name) self.name = name self.dataType = dataType self.nullable = nullable @@ -613,8 +605,6 @@ def toInternal(self, obj): else: if isinstance(obj, dict): return tuple(obj.get(n) for n in self.names) - elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): - return tuple(obj[n] for n in self.names) elif isinstance(obj, (list, tuple)): return tuple(obj) elif hasattr(obj, "__dict__"): @@ -904,19 +894,9 @@ def _parse_datatype_json_value(json_value): datetime.date: DateType, datetime.datetime: TimestampType, datetime.time: TimestampType, + bytes: BinaryType, } -if sys.version < "3": - _type_mappings.update({ - unicode: StringType, - long: LongType, - }) - -if sys.version >= "3": - _type_mappings.update({ - bytes: BinaryType, - }) - # Mapping Python array types to Spark SQL DataType # We should be careful here. The size of these types in python depends on C # implementation. 
We need to make sure that this conversion does not lose any @@ -990,20 +970,6 @@ def _int_size_to_type(size): if sys.version_info[0] < 4: _array_type_mappings['u'] = StringType -# Type code 'c' are only available at python 2 -if sys.version_info[0] < 3: - _array_type_mappings['c'] = StringType - -# SPARK-21465: -# In python2, array of 'L' happened to be mistakenly, just partially supported. To -# avoid breaking user's code, we should keep this partial support. Below is a -# dirty hacking to keep this partial support and pass the unit test. -import platform -if sys.version_info[0] < 3 and platform.python_implementation() != 'PyPy': - if 'L' not in _array_type_mappings.keys(): - _array_type_mappings['L'] = LongType - _array_unsigned_int_typecode_ctype_mappings['L'] = ctypes.c_uint - def _infer_type(obj): """Infer the DataType from obj @@ -1187,14 +1153,14 @@ def convert_struct(obj): _acceptable_types = { BooleanType: (bool,), - ByteType: (int, long), - ShortType: (int, long), - IntegerType: (int, long), - LongType: (int, long), + ByteType: (int,), + ShortType: (int,), + IntegerType: (int,), + LongType: (int,), FloatType: (float,), DoubleType: (float,), DecimalType: (decimal.Decimal,), - StringType: (str, unicode), + StringType: (str,), BinaryType: (bytearray, bytes), DateType: (datetime.date, datetime.datetime), TimestampType: (datetime.datetime,), @@ -1376,10 +1342,6 @@ def verify_struct(obj): if isinstance(obj, dict): for f, verifier in verifiers: verifier(obj.get(f)) - elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): - # the order in obj could be different than dataType.fields - for f, verifier in verifiers: - verifier(obj[f]) elif isinstance(obj, (tuple, list)): if len(obj) != len(verifiers): raise ValueError( @@ -1438,21 +1400,11 @@ class Row(tuple): NOTE: As of Spark 3.0.0, Rows created from named arguments no longer have field names sorted alphabetically and will be ordered in the position as - entered. To enable sorting for Rows compatible with Spark 2.x, set the - environment variable "PYSPARK_ROW_FIELD_SORTING_ENABLED" to "true". This - option is deprecated and will be removed in future versions of Spark. For - Python versions < 3.6, the order of named arguments is not guaranteed to - be the same as entered, see https://www.python.org/dev/peps/pep-0468. In - this case, a warning will be issued and the Row will fallback to sort the - field names automatically. - - NOTE: Examples with Row in pydocs are run with the environment variable - "PYSPARK_ROW_FIELD_SORTING_ENABLED" set to "true" which results in output - where fields are sorted. + entered. >>> row = Row(name="Alice", age=11) >>> row - Row(age=11, name='Alice') + Row(name='Alice', age=11) >>> row['name'], row['age'] ('Alice', 11) >>> row.name, row.age @@ -1476,47 +1428,22 @@ class Row(tuple): Row(name='Alice', age=11) This form can also be used to create rows as tuple values, i.e. with unnamed - fields. Beware that such Row objects have different equality semantics: + fields. 
>>> row1 = Row("Alice", 11) >>> row2 = Row(name="Alice", age=11) >>> row1 == row2 - False - >>> row3 = Row(a="Alice", b=11) - >>> row1 == row3 True """ - # Remove after Python < 3.6 dropped, see SPARK-29748 - _row_field_sorting_enabled = \ - os.environ.get('PYSPARK_ROW_FIELD_SORTING_ENABLED', 'false').lower() == 'true' - - if _row_field_sorting_enabled: - warnings.warn("The environment variable 'PYSPARK_ROW_FIELD_SORTING_ENABLED' " - "is deprecated and will be removed in future versions of Spark") - def __new__(cls, *args, **kwargs): if args and kwargs: raise ValueError("Can not use both args " "and kwargs to create Row") if kwargs: - if not Row._row_field_sorting_enabled and sys.version_info[:2] < (3, 6): - warnings.warn("To use named arguments for Python version < 3.6, Row fields will be " - "automatically sorted. This warning can be skipped by setting the " - "environment variable 'PYSPARK_ROW_FIELD_SORTING_ENABLED' to 'true'.") - Row._row_field_sorting_enabled = True - # create row objects - if Row._row_field_sorting_enabled: - # Remove after Python < 3.6 dropped, see SPARK-29748 - names = sorted(kwargs.keys()) - row = tuple.__new__(cls, [kwargs[n] for n in names]) - row.__fields__ = names - row.__from_dict__ = True - else: - row = tuple.__new__(cls, list(kwargs.values())) - row.__fields__ = list(kwargs.keys()) - + row = tuple.__new__(cls, list(kwargs.values())) + row.__fields__ = list(kwargs.keys()) return row else: # create row class or objects @@ -1537,7 +1464,7 @@ def asDict(self, recursive=False): >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} True >>> row = Row(key=1, value=Row(name='a', age=2)) - >>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')} + >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)} True >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} True @@ -1600,7 +1527,7 @@ def __getattr__(self, item): raise AttributeError(item) def __setattr__(self, key, value): - if key != '__fields__' and key != "__from_dict__": + if key != '__fields__': raise Exception("Row is read-only") self.__dict__[key] = value diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index da68583b04e1c..100481cf12899 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -21,7 +21,7 @@ import sys from pyspark import SparkContext, since -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string from pyspark.sql.pandas.types import to_arrow_type @@ -232,7 +232,6 @@ class UDFRegistration(object): def __init__(self, sparkSession): self.sparkSession = sparkSession - @ignore_unicode_prefix @since("1.3.1") def register(self, name, f, returnType=None): """Register a Python function (including lambda function) or a user-defined function @@ -261,10 +260,10 @@ def register(self, name, f, returnType=None): >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] + [Row(stringLengthString(test)='4')] >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] + [Row(stringLengthString(text)='3')] >>> from pyspark.sql.types import IntegerType >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), 
IntegerType()) @@ -349,7 +348,6 @@ def register(self, name, f, returnType=None): self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) return return_udf - @ignore_unicode_prefix @since(2.3) def registerJavaFunction(self, name, javaClassName, returnType=None): """Register a Java user-defined function as a SQL function. @@ -389,7 +387,6 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) - @ignore_unicode_prefix @since(2.3) def registerJavaUDAF(self, name, javaClassName): """Register a Java user-defined aggregate function as a SQL function. @@ -403,7 +400,7 @@ def registerJavaUDAF(self, name, javaClassName): >>> df.createOrReplaceTempView("df") >>> q = "SELECT name, javaUDAF(id) as avg from df group by name order by name desc" >>> spark.sql(q).collect() # doctest: +SKIP - [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + [Row(name='b', avg=102.0), Row(name='a', avg=102.0)] """ self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) @@ -419,9 +416,6 @@ def _test(): .appName("sql.udf tests")\ .getOrCreate() globs['spark'] = spark - # Hack to skip the unit tests in register. These are currently being tested in proper tests. - # We should reenable this test once we completely drop Python 2. - del pyspark.sql.udf.UDFRegistration.register (failure_count, test_count) = doctest.testmod( pyspark.sql.udf, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 1d5bc49d252e2..bd76d880055cd 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -16,22 +16,9 @@ # import py4j -import sys from pyspark import SparkContext -if sys.version_info.major >= 3: - unicode = str - # Disable exception chaining (PEP 3134) in captured exceptions - # in order to hide JVM stacktace. - exec(""" -def raise_from(e): - raise e from None -""") -else: - def raise_from(e): - raise e - class CapturedException(Exception): def __init__(self, desc, stackTrace, cause=None): @@ -45,11 +32,7 @@ def __str__(self): desc = self.desc if debug_enabled: desc = desc + "\n\nJVM stacktrace:\n%s" % self.stackTrace - # encode unicode instance for python2 for human readable description - if sys.version_info.major < 3 and isinstance(desc, unicode): - return str(desc.encode('utf-8')) - else: - return str(desc) + return str(desc) class AnalysisException(CapturedException): @@ -131,7 +114,7 @@ def deco(*a, **kw): if not isinstance(converted, UnknownException): # Hide where the exception came from that shows a non-Pythonic # JVM exception message. - raise_from(converted) + raise converted from None else: raise return deco diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 6199611940dc9..170f0c0ef7593 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -15,8 +15,6 @@ # limitations under the License. 
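On the python/pyspark/sql/utils.py hunk above: the exec-wrapped raise_from helper existed only because `raise e from None` is a syntax error on Python 2; on Python 3 the statement is written inline. A toy reproduction of the behaviour it keeps, using hypothetical exception names rather than Spark's:

class CapturedError(Exception):
    pass

def convert_and_raise(jvm_error):
    # Suppress the chained traceback (PEP 409), as deco() now does with
    # `raise converted from None`, so only the Pythonic error is shown.
    raise CapturedError(str(jvm_error)) from None

try:
    try:
        raise RuntimeError("py4j-side failure")   # stand-in for a Py4JJavaError
    except RuntimeError as e:
        convert_and_raise(e)
except CapturedError as e:
    print(e.__cause__, e.__suppress_context__)    # None True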
# -from __future__ import print_function - from py4j.java_gateway import java_import, is_instance_of from pyspark import RDD, SparkConf diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 60562a6c92aff..000318588ef88 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -21,11 +21,6 @@ from itertools import chain from datetime import datetime -if sys.version < "3": - from itertools import imap as map, ifilter as filter -else: - long = int - from py4j.protocol import Py4JJavaError from pyspark import RDD @@ -404,7 +399,7 @@ def _jtime(self, timestamp): """ if isinstance(timestamp, datetime): timestamp = time.mktime(timestamp.timetuple()) - return self._sc._jvm.Time(long(timestamp * 1000)) + return self._sc._jvm.Time(int(timestamp * 1000)) def slice(self, begin, end): """ diff --git a/python/pyspark/streaming/tests/test_dstream.py b/python/pyspark/streaming/tests/test_dstream.py index 7ecdf6b0b12db..89edb23070c69 100644 --- a/python/pyspark/streaming/tests/test_dstream.py +++ b/python/pyspark/streaming/tests/test_dstream.py @@ -30,8 +30,9 @@ @unittest.skipIf( - "pypy" in platform.python_implementation().lower() and "COVERAGE_PROCESS_START" in os.environ, - "PyPy implementation causes to hang DStream tests forever when Coverage report is used.") + "pypy" in platform.python_implementation().lower(), + "The tests fail in PyPy3 implementation for an unknown reason. " + "With PyPy, it causes to hang DStream tests forever when Coverage report is used.") class BasicOperationTests(PySparkStreamingTestCase): def test_map(self): @@ -394,8 +395,9 @@ def failed_func(i): @unittest.skipIf( - "pypy" in platform.python_implementation().lower() and "COVERAGE_PROCESS_START" in os.environ, - "PyPy implementation causes to hang DStream tests forever when Coverage report is used.") + "pypy" in platform.python_implementation().lower(), + "The tests fail in PyPy3 implementation for an unknown reason. " + "With PyPy, it causes to hang DStream tests forever when Coverage report is used.") class WindowFunctionTests(PySparkStreamingTestCase): timeout = 15 @@ -474,8 +476,9 @@ def func(dstream): @unittest.skipIf( - "pypy" in platform.python_implementation().lower() and "COVERAGE_PROCESS_START" in os.environ, - "PyPy implementation causes to hang DStream tests forever when Coverage report is used.") + "pypy" in platform.python_implementation().lower(), + "The tests fail in PyPy3 implementation for an unknown reason. " + "With PyPy, it causes to hang DStream tests forever when Coverage report is used.") class CheckpointTests(unittest.TestCase): setupCalled = False diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 8f419a5e8446a..d8aa5f93182e2 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -14,10 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - -from __future__ import print_function -import json - from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index 085fce6daa4ec..e85cae7dda2c6 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -24,7 +24,6 @@ from pyspark.sql import SparkSession from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row from pyspark.testing.utils import ReusedPySparkTestCase -from pyspark.util import _exception_message pandas_requirement_message = None @@ -33,7 +32,7 @@ require_minimum_pandas_version() except ImportError as e: # If Pandas version requirement is not satisfied, skip related tests. - pandas_requirement_message = _exception_message(e) + pandas_requirement_message = str(e) pyarrow_requirement_message = None try: @@ -41,14 +40,14 @@ require_minimum_pyarrow_version() except ImportError as e: # If Arrow version requirement is not satisfied, skip related tests. - pyarrow_requirement_message = _exception_message(e) + pyarrow_requirement_message = str(e) test_not_compiled_message = None try: from pyspark.sql.utils import require_test_compiled require_test_compiled() except Exception as e: - test_not_compiled_message = _exception_message(e) + test_not_compiled_message = str(e) have_pandas = pandas_requirement_message is None have_pyarrow = pyarrow_requirement_message is None diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py index 5833bf9f96fb3..168299e385e78 100644 --- a/python/pyspark/tests/test_context.py +++ b/python/pyspark/tests/test_context.py @@ -267,6 +267,14 @@ def test_resources(self): resources = sc.resources self.assertEqual(len(resources), 0) + def test_disallow_to_create_spark_context_in_executors(self): + # SPARK-32160: SparkContext should not be created in executors. 
+ with SparkContext("local-cluster[3, 1, 1024]") as sc: + with self.assertRaises(Exception) as context: + sc.range(2).foreach(lambda _: SparkContext()) + self.assertIn("SparkContext should only be created and accessed on the driver.", + str(context.exception)) + class ContextTestsWithResources(unittest.TestCase): diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py index 04ca5a3896bf4..dbce72a0d3489 100644 --- a/python/pyspark/tests/test_profiler.py +++ b/python/pyspark/tests/test_profiler.py @@ -19,15 +19,11 @@ import sys import tempfile import unittest +from io import StringIO from pyspark import SparkConf, SparkContext, BasicProfiler from pyspark.testing.utils import PySparkTestCase -if sys.version >= "3": - from io import StringIO -else: - from StringIO import StringIO - class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index 6c5b818056f2d..1a580e27ea527 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -32,9 +32,6 @@ MarshalSerializer, UTF8Deserializer, NoOpSerializer from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest -if sys.version_info[0] >= 3: - xrange = range - global_func = lambda: "Hi" @@ -193,15 +190,13 @@ def test_deleting_input_files(self): def test_sampling_default_seed(self): # Test for SPARK-3995 (default seed setting) - data = self.sc.parallelize(xrange(1000), 1) + data = self.sc.parallelize(range(1000), 1) subset = data.takeSample(False, 10) self.assertEqual(len(subset), 10) def test_aggregate_mutable_zero_value(self): # Test for SPARK-9021; uses aggregate and treeAggregate to build dict # representing a counter of ints - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility from collections import defaultdict # Show that single or multiple partitions work @@ -262,8 +257,6 @@ def comboOp(x, y): def test_fold_mutable_zero_value(self): # Test for SPARK-9021; uses fold to merge an RDD of dict counters into # a single dict - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility from collections import defaultdict counts1 = defaultdict(int, dict((i, 1) for i in range(10))) @@ -439,7 +432,7 @@ def run(f, sc): def test_large_closure(self): N = 200000 - data = [float(i) for i in xrange(N)] + data = [float(i) for i in range(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) self.assertEqual(N, rdd.first()) # regression test for SPARK-6886 @@ -464,8 +457,8 @@ def test_zip_with_different_serializers(self): def test_zip_with_different_object_sizes(self): # regress test for SPARK-5973 - a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) - b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) + a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i) self.assertEqual(10000, a.zip(b).count()) def test_zip_with_different_number_of_items(self): @@ -487,7 +480,7 @@ def test_zip_with_different_number_of_items(self): self.assertRaises(Exception, lambda: a.zip(b).count()) def test_count_approx_distinct(self): - rdd = self.sc.parallelize(xrange(1000)) + rdd = self.sc.parallelize(range(1000)) self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) @@ -641,7 +634,7 @@ def test_distinct(self): def 
test_external_group_by_key(self): self.sc._conf.set("spark.python.worker.memory", "1m") N = 2000001 - kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) + kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x)) gkv = kv.groupByKey().cache() self.assertEqual(3, gkv.count()) filtered = gkv.filter(lambda kv: kv[0] == 1) @@ -698,7 +691,7 @@ def test_multiple_python_java_RDD_conversions(self): # Regression test for SPARK-6294 def test_take_on_jrdd(self): - rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) + rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x)) rdd._jrdd.first() def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py index 734b7e4789f61..faa006c7d82e5 100644 --- a/python/pyspark/tests/test_readwrite.py +++ b/python/pyspark/tests/test_readwrite.py @@ -38,104 +38,6 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name) - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", - "org.apache.hadoop.io.DoubleWritable", - "org.apache.hadoop.io.Text").collect()) - ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.assertEqual(doubles, ed) - - bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BytesWritable").collect()) - ebs = [(1, bytearray('aa', 'utf-8')), - (1, bytearray('aa', 'utf-8')), - (2, bytearray('aa', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (3, bytearray('cc', 'utf-8'))] - self.assertEqual(bytes, ebs) - - text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", - "org.apache.hadoop.io.Text", - "org.apache.hadoop.io.Text").collect()) - et = [(u'1', u'aa'), - (u'1', u'aa'), - (u'2', u'aa'), - (u'2', u'bb'), - (u'2', u'bb'), - (u'3', u'cc')] - self.assertEqual(text, et) - - bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.assertEqual(bools, eb) - - nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - self.assertEqual(nulls, en) - - maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect() - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - for v in maps: - self.assertTrue(v in em) - - # arrays get pickled to tuples by default - tuples = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable").collect()) - et = [(1, ()), - (2, (3.0, 4.0, 5.0)), - (3, (4.0, 5.0, 6.0))] - self.assertEqual(tuples, et) - - # with 
custom converters, primitive arrays can stay as arrays - arrays = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - ea = [(1, array('d')), - (2, array('d', [3.0, 4.0, 5.0])), - (3, array('d', [4.0, 5.0, 6.0]))] - self.assertEqual(arrays, ea) - - clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable").collect()) - cname = u'org.apache.spark.api.python.TestWritable' - ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), - (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), - (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), - (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), - (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] - self.assertEqual(clazz, ec) - - unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable", - ).collect()) - self.assertEqual(unbatched_clazz, ec) - def test_oldhadoop(self): basepath = self.tempdir.name ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", @@ -249,51 +151,6 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.tempdir.name, ignore_errors=True) - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") - ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) - self.assertEqual(ints, ei) - - ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") - doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) - self.assertEqual(doubles, ed) - - ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] - self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") - bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) - self.assertEqual(bytes, ebs) - - et = [(u'1', u'aa'), - (u'2', u'bb'), - (u'3', u'cc')] - self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") - text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) - self.assertEqual(text, et) - - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") - bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) - self.assertEqual(bools, eb) - - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") - nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) - self.assertEqual(nulls, en) - - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") - maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() - for v in maps: - self.assertTrue(v, em) - def test_oldhadoop(self): basepath = self.tempdir.name dict_data = [(1, {}), @@ -361,46 +218,6 @@ def 
test_newhadoop(self): conf=input_conf).collect()) self.assertEqual(new_dataset, data) - @unittest.skipIf(sys.version >= "3", "serialize of array") - def test_newhadoop_with_array(self): - basepath = self.tempdir.name - # use custom ArrayWritable types and converters to handle arrays - array_data = [(1, array('d')), - (1, array('d', [1.0, 2.0, 3.0])), - (2, array('d', [3.0, 4.0, 5.0]))] - self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - result = sorted(self.sc.newAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - self.assertEqual(result, array_data) - - conf = { - "mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" - } - self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( - conf, - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} - new_dataset = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", - conf=input_conf).collect()) - self.assertEqual(new_dataset, array_data) - def test_newolderror(self): basepath = self.tempdir.name rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py index d50ba632d6cd4..434414618e59d 100644 --- a/python/pyspark/tests/test_shuffle.py +++ b/python/pyspark/tests/test_shuffle.py @@ -23,15 +23,12 @@ from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter -if sys.version_info[0] >= 3: - xrange = range - class MergerTests(unittest.TestCase): def setUp(self): self.N = 1 << 12 - self.l = [i for i in xrange(self.N)] + self.l = [i for i in range(self.N)] self.data = list(zip(self.l, self.l)) self.agg = Aggregator(lambda x: [x], lambda x, y: x.append(y) or x, @@ -42,26 +39,26 @@ def test_small_dataset(self): m.mergeValues(self.data) self.assertEqual(m.spills, 0) self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) + sum(range(self.N))) m = ExternalMerger(self.agg, 1000) m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) self.assertEqual(m.spills, 0) self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) + sum(range(self.N))) def test_medium_dataset(self): m = ExternalMerger(self.agg, 20) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) + sum(range(self.N))) m = ExternalMerger(self.agg, 10) m.mergeCombiners(map(lambda x_y2: 
(x_y2[0], [x_y2[1]]), self.data * 3)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N)) * 3) + sum(range(self.N)) * 3) def test_huge_dataset(self): m = ExternalMerger(self.agg, 5, partitions=3) diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index 90e4bcdfadc03..8c2bedbe4e212 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -26,9 +26,6 @@ from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext from pyspark.testing.utils import PySparkTestCase, SPARK_HOME -if sys.version_info[0] >= 3: - xrange = range - class TaskContextTests(PySparkTestCase): @@ -251,9 +248,9 @@ def context_barrier(x): def test_task_context_correct_with_python_worker_reuse(self): """Verify the task context correct when reused python worker""" # start a normal job first to start all workers and get all worker pids - worker_pids = self.sc.parallelize(xrange(2), 2).map(lambda x: os.getpid()).collect() + worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect() # the worker will reuse in this barrier job - rdd = self.sc.parallelize(xrange(10), 2) + rdd = self.sc.parallelize(range(10), 2) def context(iterator): tp = TaskContext.get().partitionId() diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py index 81bfb66e7019d..511d62a51f3df 100644 --- a/python/pyspark/tests/test_util.py +++ b/python/pyspark/tests/test_util.py @@ -61,14 +61,12 @@ def set(self, x=None, other=None, other_x=None): class UtilTests(PySparkTestCase): - def test_py4j_exception_message(self): - from pyspark.util import _exception_message - + def test_py4j_str(self): with self.assertRaises(Py4JJavaError) as context: # This attempts java.lang.String(null) which throws an NPE. 
self.sc._jvm.java.lang.String(None) - self.assertTrue('NullPointerException' in _exception_message(context.exception)) + self.assertTrue('NullPointerException' in str(context.exception)) def test_parsing_version_string(self): from pyspark.util import VersionUtils diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index dba9298ee161a..3b1848dcfdee9 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -32,9 +32,6 @@ from pyspark import SparkConf, SparkContext from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest -if sys.version_info[0] >= 3: - xrange = range - class WorkerTests(ReusedPySparkTestCase): def test_cancel_task(self): @@ -88,13 +85,13 @@ def run(): self.fail("daemon had been killed") # run a normal job - rdd = self.sc.parallelize(xrange(100), 1) + rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) def test_after_exception(self): def raise_exception(_): raise Exception() - rdd = self.sc.parallelize(xrange(100), 1) + rdd = self.sc.parallelize(range(100), 1) with QuietTest(self.sc): self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) self.assertEqual(100, rdd.map(str).count()) @@ -110,22 +107,22 @@ def test_after_jvm_exception(self): with QuietTest(self.sc): self.assertRaises(Exception, lambda: filtered_data.count()) - rdd = self.sc.parallelize(xrange(100), 1) + rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) def test_accumulator_when_reuse_worker(self): from pyspark.accumulators import INT_ACCUMULATOR_PARAM acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x)) self.assertEqual(sum(range(100)), acc1.value) acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x)) self.assertEqual(sum(range(100)), acc2.value) self.assertEqual(sum(range(100)), acc1.value) def test_reuse_worker_after_take(self): - rdd = self.sc.parallelize(xrange(100000), 1) + rdd = self.sc.parallelize(range(100000), 1) self.assertEqual(0, rdd.first()) def count(): @@ -160,17 +157,13 @@ def f(): self.sc.parallelize([1]).map(lambda x: f()).count() except Py4JJavaError as e: - if sys.version_info.major < 3: - # we have to use unicode here to avoid UnicodeDecodeError - self.assertRegexpMatches(unicode(e).encode("utf-8"), "exception with 中") - else: - self.assertRegexpMatches(str(e), "exception with 中") + self.assertRegexpMatches(str(e), "exception with 中") class WorkerReuseTest(PySparkTestCase): - def test_reuse_worker_of_parallelize_xrange(self): - rdd = self.sc.parallelize(xrange(20), 8) + def test_reuse_worker_of_parallelize_range(self): + rdd = self.sc.parallelize(range(20), 8) previous_pids = rdd.map(lambda x: os.getpid()).collect() current_pids = rdd.map(lambda x: os.getpid()).collect() for pid in current_pids: @@ -189,7 +182,7 @@ def setUp(self): self.sc = SparkContext('local[4]', class_name, conf=conf) def test_memory_limit(self): - rdd = self.sc.parallelize(xrange(1), 1) + rdd = self.sc.parallelize(range(1), 1) def getrlimit(): import resource diff --git a/python/pyspark/util.py b/python/pyspark/util.py index d9429372a6bfc..c003586e9c03b 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -19,52 +19,10 @@ import re import sys import traceback 
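# A standalone sketch (plain Exception here rather than the Py4JJavaError the tests
# exercise) of why the Python-2-only unicode(e).encode("utf-8") branch removed above,
# and the _exception_message() helper removed just below, are no longer needed: on
# Python 3, str() on an exception carrying non-ASCII text already yields a usable str.
try:
    raise Exception("exception with 中")
except Exception as e:
    message = str(e)
    assert "中" in message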
-import os -import warnings -import inspect -from py4j.protocol import Py4JJavaError __all__ = [] -def _exception_message(excp): - """Return the message from an exception as either a str or unicode object. Supports both - Python 2 and Python 3. - - >>> msg = "Exception message" - >>> excp = Exception(msg) - >>> msg == _exception_message(excp) - True - - >>> msg = u"unicöde" - >>> excp = Exception(msg) - >>> msg == _exception_message(excp) - True - """ - if isinstance(excp, Py4JJavaError): - # 'Py4JJavaError' doesn't contain the stack trace available on the Java side in 'message' - # attribute in Python 2. We should call 'str' function on this exception in general but - # 'Py4JJavaError' has an issue about addressing non-ascii strings. So, here we work - # around by the direct call, '__str__()'. Please see SPARK-23517. - return excp.__str__() - if hasattr(excp, "message"): - return excp.message - return str(excp) - - -def _get_argspec(f): - """ - Get argspec of a function. Supports both Python 2 and Python 3. - """ - if sys.version_info[0] < 3: - argspec = inspect.getargspec(f) - else: - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - argspec = inspect.getfullargspec(f) - return argspec - - def print_exec(stream): ei = sys.exc_info() traceback.print_exception(ei[0], ei[1], ei[2], None, stream) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5f4a8a2d2db1f..9b54affb137f5 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -18,11 +18,11 @@ """ Worker that receives input from Piped RDD. """ -from __future__ import print_function -from __future__ import absolute_import import os import sys import time +from inspect import getfullargspec +import importlib # 'resource' is a Unix specific module. has_resource_module = True try: @@ -44,14 +44,9 @@ from pyspark.sql.pandas.serializers import ArrowStreamPandasUDFSerializer, CogroupUDFSerializer from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StructType -from pyspark.util import _get_argspec, fail_on_stopiteration +from pyspark.util import fail_on_stopiteration from pyspark import shuffle -if sys.version >= '3': - basestring = str -else: - from itertools import imap as map # use iterator map by default - pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -272,10 +267,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: return arg_offsets, wrap_pandas_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - argspec = _get_argspec(chained_func) # signature was lost when wrapping it + argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: - argspec = _get_argspec(chained_func) # signature was lost when wrapping it + argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) @@ -342,11 +337,13 @@ def read_udfs(pickleSer, infile, eval_type): pickleSer, infile, eval_type, runner_conf, udf_index=0) def func(_, iterator): - num_input_rows = [0] # TODO(SPARK-29909): Use nonlocal after we drop Python 2. 
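# A standalone sketch (made-up helper names, not the worker code) contrasting the two
# counter styles around this TODO: Python 2 had no `nonlocal`, so a one-element list
# was mutated in place; with Python 3 only, a plain integer plus `nonlocal` suffices,
# which is exactly the change applied below.
def make_counter_py2_style():
    count = [0]                 # mutable cell: the old workaround

    def bump(n):
        count[0] += n
        return count[0]
    return bump


def make_counter_py3_style():
    count = 0

    def bump(n):
        nonlocal count          # rebind the enclosing variable directly
        count += n
        return count
    return bump


assert make_counter_py2_style()(5) == 5
assert make_counter_py3_style()(5) == 5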
+ num_input_rows = 0 def map_batch(batch): + nonlocal num_input_rows + udf_args = [batch[offset] for offset in arg_offsets] - num_input_rows[0] += len(udf_args[0]) + num_input_rows += len(udf_args[0]) if len(udf_args) == 1: return udf_args[0] else: @@ -363,7 +360,7 @@ def map_batch(batch): # by consuming the input iterator in user side. Therefore, # it's very unlikely the output length is higher than # input length. - assert is_map_iter or num_output_rows <= num_input_rows[0], \ + assert is_map_iter or num_output_rows <= num_input_rows, \ "Pandas SCALAR_ITER UDF outputted more rows than input rows." yield (result_batch, result_type) @@ -376,11 +373,11 @@ def map_batch(batch): raise RuntimeError("pandas iterator UDF should exhaust the input " "iterator.") - if num_output_rows != num_input_rows[0]: + if num_output_rows != num_input_rows: raise RuntimeError( "The length of output in Scalar iterator pandas UDF should be " "the same with the input's; however, the length of output was %d and the " - "length of input was %d." % (num_output_rows, num_input_rows[0])) + "length of input was %d." % (num_output_rows, num_input_rows)) # profiling is not supported for UDF return func, None, ser, ser @@ -548,9 +545,8 @@ def main(infile, outfile): for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) add_path(os.path.join(spark_files_dir, filename)) - if sys.version > '3': - import importlib - importlib.invalidate_caches() + + importlib.invalidate_caches() # fetch names and values of broadcast variables needs_broadcast_decryption_server = read_bool(infile) diff --git a/python/run-tests.py b/python/run-tests.py index b677a5134ec93..db62f964791ac 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -28,10 +28,7 @@ from threading import Thread, Lock import time import uuid -if sys.version < '3': - import Queue -else: - import queue as Queue +import queue as Queue from multiprocessing import Manager @@ -75,7 +72,6 @@ def run_individual_python_test(target_dir, test_name, pyspark_python): 'SPARK_PREPEND_CLASSES': '1', 'PYSPARK_PYTHON': which(pyspark_python), 'PYSPARK_DRIVER_PYTHON': which(pyspark_python), - 'PYSPARK_ROW_FIELD_SORTING_ENABLED': 'true' }) # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is @@ -161,7 +157,7 @@ def run_individual_python_test(target_dir, test_name, pyspark_python): def get_default_python_executables(): - python_execs = [x for x in ["python3.6", "python2.7", "pypy"] if which(x)] + python_execs = [x for x in ["python3.6", "python3.8", "pypy3"] if which(x)] if "python3.6" not in python_execs: p = which("python3") @@ -272,7 +268,7 @@ def main(): [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) if should_test_modules: for module in modules_to_test: - if python_implementation not in module.blacklisted_python_implementations: + if python_implementation not in module.excluded_python_implementations: for test_goal in module.python_test_goals: heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests', 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests'] diff --git a/python/setup.py b/python/setup.py index afbd601b04a94..c456a32fea87c 100755 --- a/python/setup.py +++ b/python/setup.py @@ -16,18 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
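# A standalone sketch (the temporary directory and module name are made up) of why
# worker.py above now calls importlib.invalidate_caches() unconditionally: when new
# modules appear on a path at runtime, invalidating the import caches guarantees the
# finders notice them before the next import.
import importlib
import os
import sys
import tempfile

staging_dir = tempfile.mkdtemp()
with open(os.path.join(staging_dir, "shipped_module.py"), "w") as f:
    f.write("VALUE = 42\n")

sys.path.insert(0, staging_dir)
importlib.invalidate_caches()

import shipped_module  # noqa: E402
assert shipped_module.VALUE == 42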
-from __future__ import print_function import glob import os import sys from setuptools import setup from shutil import copyfile, copytree, rmtree -if sys.version_info < (2, 7): - print("Python versions prior to 2.7 are not supported for pip installed PySpark.", - file=sys.stderr) - sys.exit(-1) - try: exec(open('pyspark/version.py').read()) except IOError: @@ -217,13 +211,10 @@ def _supports_symlinks(): 'pyarrow>=%s' % _minimum_pyarrow_version, ] }, + python_requires='>=3.6', classifiers=[ 'Development Status :: 5 - Production/Stable', 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', diff --git a/resource-managers/kubernetes/integration-tests/tests/pyfiles.py b/resource-managers/kubernetes/integration-tests/tests/pyfiles.py index ba55b75803276..51c0160554866 100644 --- a/resource-managers/kubernetes/integration-tests/tests/pyfiles.py +++ b/resource-managers/kubernetes/integration-tests/tests/pyfiles.py @@ -14,9 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from __future__ import print_function - import sys from pyspark.sql import SparkSession diff --git a/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py b/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py index d312a29f388e4..74559a0b54402 100644 --- a/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py +++ b/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py @@ -15,8 +15,6 @@ # limitations under the License. # -from __future__ import print_function - import resource import sys diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index 8dc123e93fe16..b8c64a28c72cd 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -92,8 +92,8 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { submissionState.map { state =>
- - + + diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 173a9b86e7de6..772906397546c 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -41,7 +41,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val queuedHeaders = driverHeader ++ submissionHeader val driverHeaders = driverHeader ++ historyHeader ++ submissionHeader ++ - Seq("Start Date", "Mesos Slave ID", "State") ++ sandboxHeader + Seq("Start Date", "Mesos Agent ID", "State") ++ sandboxHeader val retryHeaders = Seq("Driver ID", "Submit Date", "Description") ++ Seq("Last Failed Status", "Next Retry Time", "Attempt Count") val queuedTable = UIUtils.listingTable(queuedHeaders, queuedRow, state.queuedDrivers) @@ -81,7 +81,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val sandboxCol = if (proxy.isDefined) { val clusterSchedulerId = parent.scheduler.getSchedulerState().frameworkId - val sandBoxUri = s"${proxy.get}/#/agents/${state.slaveId.getValue}/frameworks/" + + val sandBoxUri = s"${proxy.get}/#/agents/${state.agentId.getValue}/frameworks/" + s"${clusterSchedulerId}/executors/${id}/browse" Sandbox } else { @@ -103,7 +103,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( - + diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index 47243e83d1335..b023cf1fa4bb2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -26,18 +26,16 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkConf, SparkEnv, TaskState} -import org.apache.spark.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config.EXECUTOR_ID import org.apache.spark.resource.ResourceInformation import org.apache.spark.scheduler.TaskDescription -import org.apache.spark.scheduler.cluster.mesos.MesosSchedulerUtils +import org.apache.spark.scheduler.cluster.mesos.MesosSchedulerBackendUtil import org.apache.spark.util.Utils private[spark] class MesosExecutorBackend extends MesosExecutor - with MesosSchedulerUtils // TODO: fix with ExecutorBackend with Logging { @@ -48,7 +46,7 @@ private[spark] class MesosExecutorBackend val mesosTaskId = TaskID.newBuilder().setValue(taskId.toString).build() driver.sendStatusUpdate(MesosTaskStatus.newBuilder() .setTaskId(mesosTaskId) - .setState(taskStateToMesos(state)) + .setState(MesosSchedulerBackendUtil.taskStateToMesos(state)) .setData(ByteString.copyFrom(data)) .build()) } @@ -57,7 +55,7 @@ private[spark] class MesosExecutorBackend driver: ExecutorDriver, executorInfo: ExecutorInfo, frameworkInfo: FrameworkInfo, - slaveInfo: SlaveInfo): Unit = { + agentInfo: SlaveInfo): Unit = { // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend. 
val cpusPerTask = executorInfo.getResourcesList.asScala @@ -78,11 +76,11 @@ private[spark] class MesosExecutorBackend val conf = new SparkConf(loadDefaults = true).setAll(properties) conf.set(EXECUTOR_ID, executorId) val env = SparkEnv.createExecutorEnv( - conf, executorId, slaveInfo.getHostname, cpusPerTask, None, isLocal = false) + conf, executorId, agentInfo.getHostname, cpusPerTask, None, isLocal = false) executor = new Executor( executorId, - slaveInfo.getHostname, + agentInfo.getHostname, env, resources = Map.empty[String, ResourceInformation]) } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 289b109a42747..26939ef23eaab 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.mesos.{Scheduler, SchedulerDriver} -import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.{SlaveID => AgentID, TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason @@ -41,7 +41,7 @@ import org.apache.spark.util.Utils * @param driverDescription Submitted driver description from * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]] * @param taskId Mesos TaskID generated for the task - * @param slaveId Slave ID that the task is assigned to + * @param agentId Agent ID that the task is assigned to * @param mesosTaskStatus The last known task status update. * @param startDate The date the task was launched * @param finishDate The date the task finished @@ -50,7 +50,7 @@ import org.apache.spark.util.Utils private[spark] class MesosClusterSubmissionState( val driverDescription: MesosDriverDescription, val taskId: TaskID, - val slaveId: SlaveID, + val agentId: AgentID, var mesosTaskStatus: Option[TaskStatus], var startDate: Date, var finishDate: Option[Date], @@ -59,7 +59,7 @@ private[spark] class MesosClusterSubmissionState( def copy(): MesosClusterSubmissionState = { new MesosClusterSubmissionState( - driverDescription, taskId, slaveId, mesosTaskStatus, startDate, finishDate, frameworkId) + driverDescription, taskId, agentId, mesosTaskStatus, startDate, finishDate, frameworkId) } } @@ -113,7 +113,7 @@ private[spark] class MesosDriverState( * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode * as Mesos tasks in a Mesos cluster. * All drivers are launched asynchronously by the framework, which will eventually be launched - * by one of the slaves in the cluster. The results of the driver will be stored in slave's task + * by one of the agents in the cluster. The results of the driver will be stored in agent's task * sandbox which is accessible by visiting the Mesos UI. * This scheduler supports recovery by persisting all its state and performs task reconciliation * on recover, which gets all the latest state for all the drivers from Mesos master. 
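// A standalone sketch (object and method names are made up) of the import-renaming
// trick used throughout these Mesos files: the protobuf message is still
// Protos.SlaveID on the wire, but aliasing it to AgentID at the import site lets the
// scheduler code speak in terms of agents without touching the protocol types.
import org.apache.mesos.Protos.{SlaveID => AgentID}

object AgentIdSketch {
  // Placeholder id, as used for e.g. a finished driver that never got an agent.
  def placeholderAgentId(): AgentID = AgentID.newBuilder().setValue("").build()

  def agentIdOf(value: String): AgentID = AgentID.newBuilder().setValue(value).build()
}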
@@ -121,7 +121,7 @@ private[spark] class MesosDriverState( private[spark] class MesosClusterScheduler( engineFactory: MesosClusterPersistenceEngineFactory, conf: SparkConf) - extends Scheduler with MesosSchedulerUtils { + extends Scheduler with MesosSchedulerUtils with MesosScheduler { var frameworkUrl: String = _ private val metricsSystem = MetricsSystem.createMetricsSystem(MetricsSystemInstances.MESOS_CLUSTER, conf, @@ -139,10 +139,10 @@ private[spark] class MesosClusterScheduler( private var frameworkId: String = null // Holds all the launched drivers and current launch state, keyed by submission id. private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]() - // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation. + // Holds a map of driver id to expected agent id that is passed to Mesos for reconciliation. // All drivers that are loaded after failover are added here, as we need get the latest // state of the tasks from Mesos. Keyed by task Id. - private val pendingRecover = new mutable.HashMap[String, SlaveID]() + private val pendingRecover = new mutable.HashMap[String, AgentID]() // Stores all the submitted drivers that hasn't been launched, keyed by submission id private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]() // All supervised drivers that are waiting to retry after termination, keyed by submission id @@ -277,7 +277,7 @@ private[spark] class MesosClusterScheduler( stateLock.synchronized { launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state => launchedDrivers(state.driverDescription.submissionId) = state - pendingRecover(state.taskId.getValue) = state.slaveId + pendingRecover(state.taskId.getValue) = state.agentId } queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d) // There is potential timing issue where a queued driver might have been launched @@ -348,10 +348,10 @@ private[spark] class MesosClusterScheduler( if (!pendingRecover.isEmpty) { // Start task reconciliation if we need to recover. 
val statuses = pendingRecover.collect { - case (taskId, slaveId) => + case (taskId, agentId) => val newStatus = TaskStatus.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId).build()) - .setSlaveId(slaveId) + .setSlaveId(agentId) .setState(MesosTaskState.TASK_STAGING) .build() launchedDrivers.get(getSubmissionIdFromTaskId(taskId)) @@ -539,14 +539,14 @@ private[spark] class MesosClusterScheduler( options ++= Seq("--py-files", formattedFiles) // --conf - val replicatedOptionsBlacklist = Set( + val replicatedOptionsExcludeList = Set( JARS.key, // Avoids duplicate classes in classpath SUBMIT_DEPLOY_MODE.key, // this would be set to `cluster`, but we need client "spark.master" // this contains the address of the dispatcher, not master ) val defaultConf = conf.getAllWithPrefix(config.DISPATCHER_DRIVER_DEFAULT_PREFIX).toMap val driverConf = desc.conf.getAll - .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } + .filter { case (key, _) => !replicatedOptionsExcludeList.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => options ++= Seq("--conf", s"${key}=${value}") } @@ -657,7 +657,7 @@ private[spark] class MesosClusterScheduler( finishedDrivers += new MesosClusterSubmissionState( submission, TaskID.newBuilder().setValue(submission.submissionId).build(), - SlaveID.newBuilder().setValue("").build(), + AgentID.newBuilder().setValue("").build(), None, null, None, @@ -731,7 +731,7 @@ private[spark] class MesosClusterScheduler( override def reregistered(driver: SchedulerDriver, masterInfo: MasterInfo): Unit = { logInfo(s"Framework re-registered with master ${masterInfo.getId}") } - override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} + override def agentLost(driver: SchedulerDriver, agentId: AgentID): Unit = {} override def error(driver: SchedulerDriver, error: String): Unit = { logError("Error received: " + error) markErr() @@ -815,13 +815,13 @@ private[spark] class MesosClusterScheduler( override def frameworkMessage( driver: SchedulerDriver, executorId: ExecutorID, - slaveId: SlaveID, + agentId: AgentID, message: Array[Byte]): Unit = {} override def executorLost( driver: SchedulerDriver, executorId: ExecutorID, - slaveId: SlaveID, + agentId: AgentID, status: Int): Unit = {} private def removeFromQueuedDrivers(subId: String): Boolean = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 0b447025c8a7a..5e7a29ac6d344 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.Future -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.Protos.{SlaveID => AgentID, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException, TaskState} @@ -40,7 +40,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalBlockStoreClient import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc.{RpcEndpointAddress, 
RpcEndpointRef} -import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.{ExecutorProcessLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -60,10 +60,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( master: String, securityManager: SecurityManager) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) - with org.apache.mesos.Scheduler with MesosSchedulerUtils { + with MesosScheduler + with MesosSchedulerUtils { - // Blacklist a slave after this many failures - private val MAX_SLAVE_FAILURES = 2 + // Blacklist a agent after this many failures + private val MAX_AGENT_FAILURES = 2 private val maxCoresOption = conf.get(config.CORES_MAX) @@ -116,10 +117,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // executor limit private var launchingExecutors = false - // SlaveID -> Slave - // This map accumulates entries for the duration of the job. Slaves are never deleted, because + // AgentID -> Agent + // This map accumulates entries for the duration of the job. Agents are never deleted, because // we need to maintain e.g. failure state and connection state. - private val slaves = new mutable.HashMap[String, Slave] + private val agents = new mutable.HashMap[String, Agent] /** * The total number of executors we aim to have. Undefined when not using dynamic allocation. @@ -147,7 +148,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private val extraCoresPerExecutor = conf.get(EXTRA_CORES_PER_EXECUTOR) // Offer constraints - private val slaveOfferConstraints = + private val agentOfferConstraints = parseConstraintString(sc.conf.get(CONSTRAINTS)) // Reject offers with mismatched constraints in seconds @@ -354,7 +355,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } /** - * Method called by Mesos to offer resources on slaves. We respond by launching an executor, + * Method called by Mesos to offer resources on agents. We respond by launching an executor, * unless we've already launched more than we wanted to. 
*/ override def resourceOffers(d: org.apache.mesos.SchedulerDriver, offers: JList[Offer]): Unit = { @@ -384,7 +385,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val (matchedOffers, unmatchedOffers) = offers.asScala.partition { offer => val offerAttributes = toAttributeMap(offer.getAttributesList) - matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + matchesAttributeRequirements(agentOfferConstraints, offerAttributes) } declineUnmatchedOffers(d, unmatchedOffers) @@ -441,7 +442,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val ports = getRangeResource(task.getResourcesList, "ports").mkString(",") logDebug(s"Launching Mesos task: ${taskId.getValue} with mem: $mem cpu: $cpus" + - s" ports: $ports" + s" on slave with slave id: ${task.getSlaveId.getValue} ") + s" ports: $ports" + s" on agent with agent id: ${task.getSlaveId.getValue} ") } driver.launchTasks( @@ -495,18 +496,18 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( var launchTasks = true - // TODO(mgummelt): combine offers for a single slave + // TODO(mgummelt): combine offers for a single agent // // round-robin create executors on the available offers while (launchTasks) { launchTasks = false for (offer <- offers) { - val slaveId = offer.getSlaveId.getValue + val agentId = offer.getSlaveId.getValue val offerId = offer.getId.getValue val resources = remainingResources(offerId) - if (canLaunchTask(slaveId, offer.getHostname, resources)) { + if (canLaunchTask(agentId, offer.getHostname, resources)) { // Create a task launchTasks = true val taskId = newMesosTaskId() @@ -517,7 +518,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val taskCPUs = executorCores(offerCPUs) val taskMemory = executorMemory(sc) - slaves.getOrElseUpdate(slaveId, new Slave(offer.getHostname)).taskIDs.add(taskId) + agents.getOrElseUpdate(agentId, new Agent(offer.getHostname)).taskIDs.add(taskId) val (resourcesLeft, resourcesToUse) = partitionTaskResources(resources, taskCPUs, taskMemory, taskGPUs) @@ -540,8 +541,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( gpusByTaskId(taskId) = taskGPUs } } else { - logDebug(s"Cannot launch a task for offer with id: $offerId on slave " + - s"with id: $slaveId. Requirements were not met for this offer.") + logDebug(s"Cannot launch a task for offer with id: $offerId on agent " + + s"with id: $agentId. 
Requirements were not met for this offer.") } } } @@ -573,7 +574,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpuResourcesToUse ++ memResourcesToUse ++ portResourcesToUse ++ gpuResourcesToUse) } - private def canLaunchTask(slaveId: String, offerHostname: String, + private def canLaunchTask(agentId: String, offerHostname: String, resources: JList[Resource]): Boolean = { val offerMem = getResource(resources, "mem") val offerCPUs = getResource(resources, "cpus").toInt @@ -587,7 +588,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpus + totalCoresAcquired <= maxCores && mem <= offerMem && numExecutors < executorLimit && - slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES && + agents.get(agentId).map(_.taskFailures).getOrElse(0) < MAX_AGENT_FAILURES && meetsPortRequirements && satisfiesLocality(offerHostname) } @@ -606,7 +607,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } // Check the locality information - val currentHosts = slaves.values.filter(_.taskIDs.nonEmpty).map(_.hostname).toSet + val currentHosts = agents.values.filter(_.taskIDs.nonEmpty).map(_.hostname).toSet val allDesiredHosts = hostToLocalTaskCount.map { case (k, v) => k }.toSet // Try to match locality for hosts which do not have executors yet, to potentially @@ -622,13 +623,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus): Unit = { val taskId = status.getTaskId.getValue - val slaveId = status.getSlaveId.getValue + val agentId = status.getSlaveId.getValue val state = mesosToTaskState(status.getState) logInfo(s"Mesos task $taskId is now ${status.getState}") stateLock.synchronized { - val slave = slaves(slaveId) + val agent = agents(agentId) // If the shuffle service is enabled, have the driver register with each one of the // shuffle services. This allows the shuffle services to clean up state associated with @@ -636,23 +637,23 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // this through Mesos, since the shuffle services are set up independently. if (state.equals(TaskState.RUNNING) && shuffleServiceEnabled && - !slave.shuffleRegistered) { + !agent.shuffleRegistered) { assume(mesosExternalShuffleClient.isDefined, "External shuffle client was not instantiated even though shuffle service is enabled.") // TODO: Remove this and allow the MesosExternalShuffleService to detect // framework termination when new Mesos Framework HTTP API is available. 
val externalShufflePort = conf.get(config.SHUFFLE_SERVICE_PORT) - logDebug(s"Connecting to shuffle service on slave $slaveId, " + - s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}") + logDebug(s"Connecting to shuffle service on agent $agentId, " + + s"host ${agent.hostname}, port $externalShufflePort for app ${conf.getAppId}") mesosExternalShuffleClient.get .registerDriverWithShuffleService( - slave.hostname, + agent.hostname, externalShufflePort, - sc.conf.get(config.STORAGE_BLOCKMANAGER_SLAVE_TIMEOUT), + sc.conf.get(config.STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT), sc.conf.get(config.EXECUTOR_HEARTBEAT_INTERVAL)) - slave.shuffleRegistered = true + agent.shuffleRegistered = true } if (TaskState.isFinished(state)) { @@ -666,16 +667,16 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( totalGpusAcquired -= gpus gpusByTaskId -= taskId } - // If it was a failure, mark the slave as failed for blacklisting purposes + // If it was a failure, mark the agent as failed for blacklisting purposes if (TaskState.isFailed(state)) { - slave.taskFailures += 1 + agent.taskFailures += 1 - if (slave.taskFailures >= MAX_SLAVE_FAILURES) { - logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " + + if (agent.taskFailures >= MAX_AGENT_FAILURES) { + logInfo(s"Blacklisting Mesos agent $agentId due to too many failures; " + "is Spark installed on it?") } } - executorTerminated(d, slaveId, taskId, s"Executor finished with state $state") + executorTerminated(d, agentId, taskId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node d.reviveOffers() } @@ -708,7 +709,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // See SPARK-12330 val startTime = System.nanoTime() - // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent + // agentIdsWithExecutors has no memory barrier, so this is eventually consistent while (numExecutors() > 0 && System.nanoTime() - startTime < shutdownTimeoutMS * 1000L * 1000L) { Thread.sleep(100) @@ -729,15 +730,15 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def frameworkMessage( - d: org.apache.mesos.SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]): Unit = {} + d: org.apache.mesos.SchedulerDriver, e: ExecutorID, s: AgentID, b: Array[Byte]): Unit = {} /** - * Called when a slave is lost or a Mesos task finished. Updates local view on + * Called when a agent is lost or a Mesos task finished. Updates local view on * what tasks are running. It also notifies the driver that an executor was removed. */ private def executorTerminated( d: org.apache.mesos.SchedulerDriver, - slaveId: String, + agentId: String, taskId: String, reason: String): Unit = { stateLock.synchronized { @@ -745,18 +746,18 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // removeExecutor() internally will send a message to the driver endpoint but // the driver endpoint is not available now, otherwise an exception will be thrown. 
if (!stopCalled) { - removeExecutor(taskId, SlaveLost(reason)) + removeExecutor(taskId, ExecutorProcessLost(reason)) } - slaves(slaveId).taskIDs.remove(taskId) + agents(agentId).taskIDs.remove(taskId) } } - override def slaveLost(d: org.apache.mesos.SchedulerDriver, slaveId: SlaveID): Unit = { - logInfo(s"Mesos slave lost: ${slaveId.getValue}") + override def agentLost(d: org.apache.mesos.SchedulerDriver, agentId: AgentID): Unit = { + logInfo(s"Mesos agent lost: ${agentId.getValue}") } override def executorLost( - d: org.apache.mesos.SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { + d: org.apache.mesos.SchedulerDriver, e: ExecutorID, s: AgentID, status: Int): Unit = { logInfo("Mesos executor lost: %s".format(e.getValue)) } @@ -770,7 +771,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( resourceProfileToTotalExecs: Map[ResourceProfile, Int] ): Future[Boolean] = Future.successful { // We don't truly know if we can fulfill the full amount of executors - // since at coarse grain it depends on the amount of slaves available. + // since at coarse grain it depends on the amount of agents available. val numExecs = resourceProfileToTotalExecs.getOrElse(defaultProfile, 0) logInfo("Capping the total amount of executors to " + numExecs) executorLimitOption = Some(numExecs) @@ -800,11 +801,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def numExecutors(): Int = { - slaves.values.map(_.taskIDs.size).sum + agents.values.map(_.taskIDs.size).sum } } -private class Slave(val hostname: String) { +private class Agent(val hostname: String) { val taskIDs = new mutable.HashSet[String]() var taskFailures = 0 var shuffleRegistered = false diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index f1e3fcab7e6af..586c2bdd67cfa 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -23,7 +23,8 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, SlaveID => AgentID, + TaskInfo => MesosTaskInfo, _} import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString @@ -46,12 +47,12 @@ private[spark] class MesosFineGrainedSchedulerBackend( sc: SparkContext, master: String) extends SchedulerBackend - with org.apache.mesos.Scheduler + with MesosScheduler with MesosSchedulerUtils { - // Stores the slave ids that has launched a Mesos executor. - val slaveIdToExecutorInfo = new HashMap[String, MesosExecutorInfo] - val taskIdToSlaveId = new HashMap[Long, String] + // Stores the agent ids that has launched a Mesos executor. 
+ val agentIdToExecutorInfo = new HashMap[String, MesosExecutorInfo] + val taskIdToAgentId = new HashMap[Long, String] // An ExecutorInfo for our tasks var execArgs: Array[Byte] = null @@ -64,7 +65,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( private[mesos] val mesosExecutorCores = sc.conf.get(mesosConfig.EXECUTOR_CORES) // Offer constraints - private[this] val slaveOfferConstraints = + private[this] val agentOfferConstraints = parseConstraintString(sc.conf.get(mesosConfig.CONSTRAINTS)) // reject offers with mismatched constraints in seconds @@ -217,7 +218,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( val builder = new StringBuilder tasks.asScala.foreach { t => builder.append("Task id: ").append(t.getTaskId.getValue).append("\n") - .append("Slave id: ").append(t.getSlaveId.getValue).append("\n") + .append("Agent id: ").append(t.getSlaveId.getValue).append("\n") .append("Task resources: ").append(t.getResourcesList).append("\n") .append("Executor resources: ").append(t.getExecutor.getResourcesList) .append("---------------------------------------------\n") @@ -226,7 +227,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( } /** - * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets + * Method called by Mesos to offer resources on agents. We respond by asking our active task sets * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that * tasks are balanced across the cluster. */ @@ -237,7 +238,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( offers.asScala.partition { o => val offerAttributes = toAttributeMap(o.getAttributesList) val meetsConstraints = - matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + matchesAttributeRequirements(agentOfferConstraints, offerAttributes) // add some debug messaging if (!meetsConstraints) { @@ -259,7 +260,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( val (usableOffers, unUsableOffers) = offersMatchingConstraints.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") - val slaveId = o.getSlaveId.getValue + val agentId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) // check offers for @@ -269,7 +270,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) val meetsRequirements = (meetsMemoryRequirements && meetsCPURequirements) || - (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + (agentIdToExecutorInfo.contains(agentId) && cpus >= scheduler.CPUS_PER_TASK) val debugstr = if (meetsRequirements) "Accepting" else "Declining" logDebug(s"$debugstr offer: ${o.getId.getValue} with attributes: " + s"$offerAttributes mem: $mem cpu: $cpus") @@ -281,10 +282,10 @@ private[spark] class MesosFineGrainedSchedulerBackend( unUsableOffers.foreach(o => d.declineOffer(o.getId)) val workerOffers = usableOffers.map { o => - val cpus = if (slaveIdToExecutorInfo.contains(o.getSlaveId.getValue)) { + val cpus = if (agentIdToExecutorInfo.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt } else { - // If the Mesos executor has not been started on this slave yet, set aside a few + // If the Mesos executor has not been started on this agent yet, set aside a few // cores for the Mesos executor by offering fewer cores to the Spark executor (getResource(o.getResourcesList, "cpus") - 
mesosExecutorCores).toInt } @@ -294,51 +295,51 @@ private[spark] class MesosFineGrainedSchedulerBackend( cpus) }.toIndexedSeq - val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap - val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap - val slaveIdToResources = new HashMap[String, JList[Resource]]() + val agentIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap + val agentIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap + val agentIdToResources = new HashMap[String, JList[Resource]]() usableOffers.foreach { o => - slaveIdToResources(o.getSlaveId.getValue) = o.getResourcesList + agentIdToResources(o.getSlaveId.getValue) = o.getResourcesList } val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] - val slavesIdsOfAcceptedOffers = HashSet[String]() + val agentsIdsOfAcceptedOffers = HashSet[String]() // Call into the TaskSchedulerImpl val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) acceptedOffers .foreach { offer => offer.foreach { taskDesc => - val slaveId = taskDesc.executorId - slavesIdsOfAcceptedOffers += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId + val agentId = taskDesc.executorId + agentsIdsOfAcceptedOffers += agentId + taskIdToAgentId(taskDesc.taskId) = agentId val (mesosTask, remainingResources) = createMesosTask( taskDesc, - slaveIdToResources(slaveId), - slaveId) - mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + agentIdToResources(agentId), + agentId) + mesosTasks.getOrElseUpdate(agentId, new JArrayList[MesosTaskInfo]) .add(mesosTask) - slaveIdToResources(slaveId) = remainingResources + agentIdToResources(agentId) = remainingResources } } // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? 
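// A standalone sketch (object and method names are made up; the Mesos calls are real)
// of the refuse-seconds mechanism behind the Filters built above: a Filters value
// passed to launchTasks or declineOffer asks Mesos not to re-offer that agent's
// resources to this framework for the given number of seconds.
import org.apache.mesos.SchedulerDriver
import org.apache.mesos.Protos.{Filters, OfferID}

object OfferFilterSketch {
  // Short refusal when the offer was used: more resources may be wanted again soon.
  def acceptFilter(): Filters = Filters.newBuilder().setRefuseSeconds(1.0).build()

  // Longer refusal for offers this framework cannot use at the moment.
  def declineUnusable(driver: SchedulerDriver, offerId: OfferID, refuseSeconds: Double): Unit =
    driver.declineOffer(offerId, Filters.newBuilder().setRefuseSeconds(refuseSeconds).build())
}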
- mesosTasks.foreach { case (slaveId, tasks) => - slaveIdToWorkerOffer.get(slaveId).foreach(o => - listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId, + mesosTasks.foreach { case (agentId, tasks) => + agentIdToWorkerOffer.get(agentId).foreach(o => + listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), agentId, // TODO: Add support for log urls for Mesos new ExecutorInfo(o.host, o.cores, Map.empty, Map.empty))) ) - logTrace(s"Launching Mesos tasks on slave '$slaveId', tasks:\n${getTasksSummary(tasks)}") - d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) + logTrace(s"Launching Mesos tasks on agent '$agentId', tasks:\n${getTasksSummary(tasks)}") + d.launchTasks(Collections.singleton(agentIdToOffer(agentId).getId), tasks, filters) } // Decline offers that weren't used // NOTE: This logic assumes that we only get a single offer for each host in a given batch - for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) { + for (o <- usableOffers if !agentsIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) { d.declineOffer(o.getId) } } @@ -348,19 +349,19 @@ private[spark] class MesosFineGrainedSchedulerBackend( def createMesosTask( task: TaskDescription, resources: JList[Resource], - slaveId: String): (MesosTaskInfo, JList[Resource]) = { + agentId: String): (MesosTaskInfo, JList[Resource]) = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() - val (executorInfo, remainingResources) = if (slaveIdToExecutorInfo.contains(slaveId)) { - (slaveIdToExecutorInfo(slaveId), resources) + val (executorInfo, remainingResources) = if (agentIdToExecutorInfo.contains(agentId)) { + (agentIdToExecutorInfo(agentId), resources) } else { - createExecutorInfo(resources, slaveId) + createExecutorInfo(resources, agentId) } - slaveIdToExecutorInfo(slaveId) = executorInfo + agentIdToExecutorInfo(agentId) = executorInfo val (finalResources, cpuResources) = partitionResources(remainingResources, "cpus", scheduler.CPUS_PER_TASK) val taskInfo = MesosTaskInfo.newBuilder() .setTaskId(taskId) - .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) + .setSlaveId(AgentID.newBuilder().setValue(agentId).build()) .setExecutor(executorInfo) .setName(task.name) .addAllResources(cpuResources.asJava) @@ -375,12 +376,12 @@ private[spark] class MesosFineGrainedSchedulerBackend( val state = mesosToTaskState(status.getState) synchronized { if (TaskState.isFailed(mesosToTaskState(status.getState)) - && taskIdToSlaveId.contains(tid)) { - // We lost the executor on this slave, so remember that it's gone - removeExecutor(taskIdToSlaveId(tid), "Lost executor") + && taskIdToAgentId.contains(tid)) { + // We lost the executor on this agent, so remember that it's gone + removeExecutor(taskIdToAgentId(tid), "Lost executor") } if (TaskState.isFinished(state)) { - taskIdToSlaveId.remove(tid) + taskIdToAgentId.remove(tid) } } scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer) @@ -406,39 +407,39 @@ private[spark] class MesosFineGrainedSchedulerBackend( } override def frameworkMessage( - d: org.apache.mesos.SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]): Unit = {} + d: org.apache.mesos.SchedulerDriver, e: ExecutorID, s: AgentID, b: Array[Byte]): Unit = {} /** - * Remove executor associated with slaveId in a thread safe manner. + * Remove executor associated with agentId in a thread safe manner. 
*/ - private def removeExecutor(slaveId: String, reason: String) = { + private def removeExecutor(agentId: String, reason: String) = { synchronized { - listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason)) - slaveIdToExecutorInfo -= slaveId + listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), agentId, reason)) + agentIdToExecutorInfo -= agentId } } - private def recordSlaveLost( - d: org.apache.mesos.SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason): Unit = { + private def recordAgentLost( + d: org.apache.mesos.SchedulerDriver, agentId: AgentID, reason: ExecutorLossReason): Unit = { inClassLoader() { - logInfo("Mesos slave lost: " + slaveId.getValue) - removeExecutor(slaveId.getValue, reason.toString) - scheduler.executorLost(slaveId.getValue, reason) + logInfo("Mesos agent lost: " + agentId.getValue) + removeExecutor(agentId.getValue, reason.toString) + scheduler.executorLost(agentId.getValue, reason) } } - override def slaveLost(d: org.apache.mesos.SchedulerDriver, slaveId: SlaveID): Unit = { - recordSlaveLost(d, slaveId, SlaveLost()) + override def agentLost(d: org.apache.mesos.SchedulerDriver, agentId: AgentID): Unit = { + recordAgentLost(d, agentId, ExecutorProcessLost()) } override def executorLost( d: org.apache.mesos.SchedulerDriver, executorId: ExecutorID, - slaveId: SlaveID, + agentId: AgentID, status: Int): Unit = { - logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, - slaveId.getValue)) - recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true)) + logInfo("Executor lost: %s, marking agent %s as lost".format(executorId.getValue, + agentId.getValue)) + recordAgentLost(d, agentId, ExecutorExited(status, exitCausedByApp = true)) } override def killTask( diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosScheduler.scala new file mode 100644 index 0000000000000..f55b9efb3e64b --- /dev/null +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosScheduler.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.mesos.Protos.{SlaveID => AgentID} + +trait MesosScheduler extends org.apache.mesos.Scheduler { + override def slaveLost(d: org.apache.mesos.SchedulerDriver, agentId: AgentID): Unit = { + agentLost(d, agentId) + } + + def agentLost(d: org.apache.mesos.SchedulerDriver, agentId: AgentID): Unit +} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 7b2f6a2535eda..981b8e9df1747 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -17,22 +17,23 @@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.mesos.Protos.{ContainerInfo, Environment, Image, NetworkInfo, Parameter, Secret, Volume} +import org.apache.mesos.Protos.{ContainerInfo, Environment, Image, NetworkInfo, Parameter, Secret, + TaskState => MesosTaskState, Volume} import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo} import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.protobuf.ByteString -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, TaskState} import org.apache.spark.SparkException import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.config.MesosSecretConfig import org.apache.spark.internal.Logging /** - * A collection of utility functions which can be used by both the - * MesosSchedulerBackend and the [[MesosFineGrainedSchedulerBackend]]. + * A collection of utility functions which can be used by the + * MesosSchedulerBackend, [[MesosFineGrainedSchedulerBackend]] and the MesosExecutorBackend. */ -private[mesos] object MesosSchedulerBackendUtil extends Logging { +private[spark] object MesosSchedulerBackendUtil extends Logging { /** * Parse a list of volume specs, each of which * takes the form [host-dir:]container-dir[:rw|:ro]. 
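The doc comment above describes the `[host-dir:]container-dir[:rw|:ro]` volume syntax. A REPL-pasteable sketch of how such a spec can be decomposed follows; VolumeSpec and parseVolume are hypothetical names, not the helpers defined in MesosSchedulerBackendUtil.

```scala
// Illustrative decomposition of a "[host-dir:]container-dir[:rw|:ro]" volume spec.
final case class VolumeSpec(hostPath: Option[String], containerPath: String, readOnly: Boolean)

def parseVolume(spec: String): Option[VolumeSpec] = spec.split(":") match {
  case Array(container)             => Some(VolumeSpec(None, container, readOnly = false))
  case Array(container, "rw")       => Some(VolumeSpec(None, container, readOnly = false))
  case Array(container, "ro")       => Some(VolumeSpec(None, container, readOnly = true))
  case Array(host, container)       => Some(VolumeSpec(Some(host), container, readOnly = false))
  case Array(host, container, "rw") => Some(VolumeSpec(Some(host), container, readOnly = false))
  case Array(host, container, "ro") => Some(VolumeSpec(Some(host), container, readOnly = true))
  case _                            => None // malformed spec
}

assert(parseVolume("/data") == Some(VolumeSpec(None, "/data", readOnly = false)))
assert(parseVolume("/host/logs:/container/logs:ro") ==
  Some(VolumeSpec(Some("/host/logs"), "/container/logs", readOnly = true)))
```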
@@ -294,4 +295,13 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .setImage(imageProto) .build } + + def taskStateToMesos(state: TaskState.TaskState): MesosTaskState = state match { + case TaskState.LAUNCHING => MesosTaskState.TASK_STARTING + case TaskState.RUNNING => MesosTaskState.TASK_RUNNING + case TaskState.FINISHED => MesosTaskState.TASK_FINISHED + case TaskState.FAILED => MesosTaskState.TASK_FAILED + case TaskState.KILLED => MesosTaskState.TASK_KILLED + case TaskState.LOST => MesosTaskState.TASK_LOST + } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index ed3bd358d4082..5784ee314aa17 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -29,7 +29,7 @@ import scala.util.control.NonFatal import com.google.common.base.Splitter import com.google.common.io.Files import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} -import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.{SlaveID => AgentID, TaskState => MesosTaskState, _} import org.apache.mesos.Protos.FrameworkInfo.Capability import org.apache.mesos.Protos.Resource.ReservationInfo import org.apache.mesos.protobuf.{ByteString, GeneratedMessageV3} @@ -304,12 +304,12 @@ trait MesosSchedulerUtils extends Logging { * Match the requirements (if any) to the offer attributes. * if attribute requirements are not specified - return true * else if attribute is defined and no values are given, simple attribute presence is performed - * else if attribute name and value is specified, subset match is performed on slave attributes + * else if attribute name and value is specified, subset match is performed on agent attributes */ def matchesAttributeRequirements( - slaveOfferConstraints: Map[String, Set[String]], + agentOfferConstraints: Map[String, Set[String]], offerAttributes: Map[String, GeneratedMessageV3]): Boolean = { - slaveOfferConstraints.forall { + agentOfferConstraints.forall { // offer has the required attribute and subsumes the required values for that attribute case (name, requiredValues) => offerAttributes.get(name) match { @@ -574,15 +574,6 @@ trait MesosSchedulerUtils extends Logging { MesosTaskState.TASK_UNREACHABLE => TaskState.LOST } - def taskStateToMesos(state: TaskState.TaskState): MesosTaskState = state match { - case TaskState.LAUNCHING => MesosTaskState.TASK_STARTING - case TaskState.RUNNING => MesosTaskState.TASK_RUNNING - case TaskState.FINISHED => MesosTaskState.TASK_FINISHED - case TaskState.FAILED => MesosTaskState.TASK_FAILED - case TaskState.KILLED => MesosTaskState.TASK_KILLED - case TaskState.LOST => MesosTaskState.TASK_LOST - } - protected def declineOffer( driver: org.apache.mesos.SchedulerDriver, offer: Offer, @@ -612,4 +603,3 @@ trait MesosSchedulerUtils extends Logging { } } } - diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 9a50142b51d97..bb37bbd2d8046 100644 --- 
a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -413,7 +413,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi new MesosDriverDescription("d1", "jar", 100, 1, true, command, Map((config.EXECUTOR_HOME.key, "test"), ("spark.app.name", "test")), "s1", new Date())) assert(response.success) - val slaveId = SlaveID.newBuilder().setValue("s1").build() + val agentId = SlaveID.newBuilder().setValue("s1").build() val offer = Offer.newBuilder() .addResources( Resource.newBuilder().setRole("*") @@ -425,7 +425,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi .setType(Type.SCALAR)) .setId(OfferID.newBuilder().setValue("o1").build()) .setFrameworkId(FrameworkID.newBuilder().setValue("f1").build()) - .setSlaveId(slaveId) + .setSlaveId(agentId) .setHostname("host1") .build() // Offer the resource to launch the submitted driver @@ -438,7 +438,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val taskStatus = TaskStatus.newBuilder() .setTaskId(TaskID.newBuilder().setValue(response.submissionId).build()) - .setSlaveId(slaveId) + .setSlaveId(agentId) .setState(MesosTaskState.TASK_KILLED) .build() // Update the status of the killed task diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 5ab277ed87a72..4d7f6441020b7 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -105,7 +105,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend.statusUpdate(driver, status) verify(driver, times(1)).reviveOffers() - // Launches a new task on a valid offer from the same slave + // Launches a new task on a valid offer from the same agent offerResources(List(offer2)) verifyTaskLaunched(driver, "o2") } @@ -250,7 +250,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite verifyTaskLaunched(driver, "o2") } - test("mesos creates multiple executors on a single slave") { + test("mesos creates multiple executors on a single agent") { val executorCores = 4 setBackend(Map(EXECUTOR_CORES.key -> executorCores.toString)) @@ -727,10 +727,10 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) - private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { + private def registerMockExecutor(executorId: String, agentId: String, cores: Integer) = { val mockEndpointRef = mock[RpcEndpointRef] val mockAddress = mock[RpcAddress] - val message = RegisterExecutor(executorId, mockEndpointRef, slaveId, cores, Map.empty, + val message = RegisterExecutor(executorId, mockEndpointRef, agentId, cores, Map.empty, Map.empty, Map.empty, ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) backend.driverEndpoint.askSync[Boolean](message) @@ -766,10 +766,10 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite } } - private def createTaskStatus(taskId: String, slaveId: 
String, state: TaskState): TaskStatus = { + private def createTaskStatus(taskId: String, agentId: String, state: TaskState): TaskStatus = { TaskStatus.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId).build()) - .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) + .setSlaveId(SlaveID.newBuilder().setValue(agentId).build()) .setState(state) .build } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 5a4bf1dd2d409..92676cc4e7395 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -46,7 +46,7 @@ object Utils { def createOffer( offerId: String, - slaveId: String, + agentId: String, mem: Int, cpus: Int, ports: Option[(Long, Long)] = None, @@ -77,8 +77,8 @@ object Utils { builder.setId(createOfferId(offerId)) .setFrameworkId(FrameworkID.newBuilder() .setValue("f1")) - .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) - .setHostname(s"host${slaveId}") + .setSlaveId(SlaveID.newBuilder().setValue(agentId)) + .setHostname(s"host${agentId}") .addAllAttributes(attributes.asJava) .build() } @@ -101,8 +101,8 @@ object Utils { OfferID.newBuilder().setValue(offerId).build() } - def createSlaveId(slaveId: String): SlaveID = { - SlaveID.newBuilder().setValue(slaveId).build() + def createAgentId(agentId: String): SlaveID = { + SlaveID.newBuilder().setValue(agentId).build() } def createExecutorId(executorId: String): ExecutorID = { @@ -227,4 +227,3 @@ object Utils { .build() } } - diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 7c67493c33160..2f272be60ba25 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -30,6 +30,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.Utils /** * Handles registering and unregistering the application with the YARN ResourceManager. @@ -107,7 +108,7 @@ private[spark] class YarnRMClient extends Logging { // so not all stable releases have it. 
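The YarnRMClient hunk continuing below replaces a bare `split(":").head` with `Utils.parseHostPort` when extracting proxy hosts, presumably so that host strings containing extra colons (for example bracketed IPv6 literals) are not truncated. The REPL-pasteable sketch below illustrates the difference; parseHostPort here is a simplified stand-in, not the actual org.apache.spark.util.Utils implementation.

```scala
// Simplified stand-in for a host:port parser (not the real Utils.parseHostPort).
def parseHostPort(hostPort: String): (String, Int) = {
  val idx = hostPort.lastIndexOf(':')
  if (idx < 0 || hostPort.endsWith("]")) (hostPort, 0)   // no port component
  else (hostPort.substring(0, idx), hostPort.substring(idx + 1).toInt)
}

// Plain host:port — both approaches agree.
assert(parseHostPort("rm-host:8088")._1 == "rm-host")
assert("rm-host:8088".split(":").head == "rm-host")

// Bracketed IPv6 — splitting on the first ':' truncates the host.
assert(parseHostPort("[2001:db8::1]:8088")._1 == "[2001:db8::1]")
assert("[2001:db8::1]:8088".split(":").head == "[2001")
```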
val prefix = WebAppUtils.getHttpSchemePrefix(conf) val proxies = WebAppUtils.getProxyHostsAndPortsForAmFilter(conf) - val hosts = proxies.asScala.map(_.split(":").head) + val hosts = proxies.asScala.map(proxy => Utils.parseHostPort(proxy)._1) val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase } val params = Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 9cefc4011c930..9d6b776a69d85 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -107,7 +107,7 @@ object YarnSparkHadoopUtil { * Not killing the task leaves various aspects of the executor and (to some extent) the jvm in * an inconsistent state. * TODO: If the OOM is not recoverable by rescheduling it on different node, then do - * 'something' to fail job ... akin to blacklisting trackers in mapred ? + * 'something' to fail job ... akin to unhealthy trackers in mapred ? * * The handler if an OOM Exception is thrown by the JVM must be configured on Windows * differently: the 'taskkill' command should be used, whereas Unix-based systems use 'kill'. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 0475b0aed0ec4..3f2e8846e85b3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -264,13 +264,14 @@ private[spark] abstract class YarnSchedulerBackend( case NonFatal(e) => logWarning(s"Attempted to get executor loss reason" + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + - s" but got no response. Marking as slave lost.", e) - RemoveExecutor(executorId, SlaveLost()) + s" but got no response. Marking as agent lost.", e) + RemoveExecutor(executorId, ExecutorProcessLost()) }(ThreadUtils.sameThread) case None => logWarning("Attempted to check for an executor loss reason" + " before the AM has registered!") - Future.successful(RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) + Future.successful(RemoveExecutor(executorId, + ExecutorProcessLost("AM is not yet registered."))) } removeExecutorMessage.foreach { message => driverEndpoint.send(message) } diff --git a/sbin/decommission-slave.sh b/sbin/decommission-slave.sh old mode 100644 new mode 100755 index 4bbf257ff1d3a..858bede1d2878 --- a/sbin/decommission-slave.sh +++ b/sbin/decommission-slave.sh @@ -17,41 +17,7 @@ # limitations under the License. # -# A shell script to decommission all workers on a single slave -# -# Environment variables -# -# SPARK_WORKER_INSTANCES The number of worker instances that should be -# running on this slave. Default is 1. - -# Usage: decommission-slave.sh [--block-until-exit] -# Decommissions all slaves on this worker machine - -set -ex - -if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -fi - -. "${SPARK_HOME}/sbin/spark-config.sh" - -. 
"${SPARK_HOME}/bin/load-spark-env.sh" - -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "${SPARK_HOME}/sbin"/spark-daemon.sh decommission org.apache.spark.deploy.worker.Worker 1 -else - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "${SPARK_HOME}/sbin"/spark-daemon.sh decommission org.apache.spark.deploy.worker.Worker $(( $i + 1 )) - done -fi +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -# Check if --block-until-exit is set. -# This is done for systems which block on the decomissioning script and on exit -# shut down the entire system (e.g. K8s). -if [ "$1" == "--block-until-exit" ]; then - shift - # For now we only block on the 0th instance if there multiple instances. - instance=$1 - pid="$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid" - wait $pid -fi +>&2 echo "This script is deprecated, use decommission-worker.sh" +"${DIR}/decommission-worker.sh" "$@" diff --git a/sbin/decommission-worker.sh b/sbin/decommission-worker.sh new file mode 100755 index 0000000000000..cf81a53f395c2 --- /dev/null +++ b/sbin/decommission-worker.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A shell script to decommission all workers on a single worker +# +# Environment variables +# +# SPARK_WORKER_INSTANCES The number of worker instances that should be +# running on this worker machine. Default is 1. + +# Usage: decommission-worker.sh [--block-until-exit] +# Decommissions all workers on this worker machine. + +set -ex + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" + +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + "${SPARK_HOME}/sbin"/spark-daemon.sh decommission org.apache.spark.deploy.worker.Worker 1 +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + "${SPARK_HOME}/sbin"/spark-daemon.sh decommission org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + done +fi + +# Check if --block-until-exit is set. +# This is done for systems which block on the decomissioning script and on exit +# shut down the entire system (e.g. K8s). +if [ "$1" == "--block-until-exit" ]; then + shift + # For now we only block on the 0th instance if there multiple instances. + instance=$1 + pid="$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid" + wait $pid +fi diff --git a/sbin/slaves.sh b/sbin/slaves.sh index c971aa3296b09..b92007ecdfad5 100755 --- a/sbin/slaves.sh +++ b/sbin/slaves.sh @@ -17,87 +17,7 @@ # limitations under the License. # -# Run a shell command on all slave hosts. -# -# Environment Variables -# -# SPARK_SLAVES File naming remote hosts. -# Default is ${SPARK_CONF_DIR}/slaves. -# SPARK_CONF_DIR Alternate conf dir. Default is ${SPARK_HOME}/conf. 
-# SPARK_SLAVE_SLEEP Seconds to sleep between spawning remote commands. -# SPARK_SSH_OPTS Options passed to ssh when running remote commands. -## - -usage="Usage: slaves.sh [--config ] command..." - -# if no args specified, show usage -if [ $# -le 0 ]; then - echo $usage - exit 1 -fi - -if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -fi - -. "${SPARK_HOME}/sbin/spark-config.sh" - -# If the slaves file is specified in the command line, -# then it takes precedence over the definition in -# spark-env.sh. Save it here. -if [ -f "$SPARK_SLAVES" ]; then - HOSTLIST=`cat "$SPARK_SLAVES"` -fi - -# Check if --config is passed as an argument. It is an optional parameter. -# Exit if the argument is not a directory. -if [ "$1" == "--config" ] -then - shift - conf_dir="$1" - if [ ! -d "$conf_dir" ] - then - echo "ERROR : $conf_dir is not a directory" - echo $usage - exit 1 - else - export SPARK_CONF_DIR="$conf_dir" - fi - shift -fi - -. "${SPARK_HOME}/bin/load-spark-env.sh" - -if [ "$HOSTLIST" = "" ]; then - if [ "$SPARK_SLAVES" = "" ]; then - if [ -f "${SPARK_CONF_DIR}/slaves" ]; then - HOSTLIST=`cat "${SPARK_CONF_DIR}/slaves"` - else - HOSTLIST=localhost - fi - else - HOSTLIST=`cat "${SPARK_SLAVES}"` - fi -fi - - - -# By default disable strict host key checking -if [ "$SPARK_SSH_OPTS" = "" ]; then - SPARK_SSH_OPTS="-o StrictHostKeyChecking=no" -fi - -for slave in `echo "$HOSTLIST"|sed "s/#.*$//;/^$/d"`; do - if [ -n "${SPARK_SSH_FOREGROUND}" ]; then - ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ - 2>&1 | sed "s/^/$slave: /" - else - ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ - 2>&1 | sed "s/^/$slave: /" & - fi - if [ "$SPARK_SLAVE_SLEEP" != "" ]; then - sleep $SPARK_SLAVE_SLEEP - fi -done +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -wait +>&2 echo "This script is deprecated, use workers.sh" +"${DIR}/workers.sh" "$@" diff --git a/sbin/spark-daemons.sh b/sbin/spark-daemons.sh index dec2f4432df39..9a5e5f3a09c1d 100755 --- a/sbin/spark-daemons.sh +++ b/sbin/spark-daemons.sh @@ -17,7 +17,7 @@ # limitations under the License. # -# Run a Spark command on all slave hosts. +# Run a Spark command on all worker hosts. usage="Usage: spark-daemons.sh [--config ] [start|stop] command instance-number args..." @@ -33,4 +33,4 @@ fi . "${SPARK_HOME}/sbin/spark-config.sh" -exec "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/spark-daemon.sh" "$@" +exec "${SPARK_HOME}/sbin/workers.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/spark-daemon.sh" "$@" diff --git a/sbin/start-all.sh b/sbin/start-all.sh index a5d30d274ea6e..064074e07922b 100755 --- a/sbin/start-all.sh +++ b/sbin/start-all.sh @@ -19,7 +19,7 @@ # Start all spark daemons. # Starts the master on this node. -# Starts a worker on each node specified in conf/slaves +# Starts a worker on each node specified in conf/workers if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" @@ -32,4 +32,4 @@ fi "${SPARK_HOME}/sbin"/start-master.sh # Start Workers -"${SPARK_HOME}/sbin"/start-slaves.sh +"${SPARK_HOME}/sbin"/start-workers.sh diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 9b3b26b07842b..68682532f02ee 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -17,76 +17,7 @@ # limitations under the License. # -# Starts a slave on the machine this script is executed on. -# -# Environment Variables -# -# SPARK_WORKER_INSTANCES The number of worker instances to run on this -# slave. Default is 1. 
Note it has been deprecate since Spark 3.0. -# SPARK_WORKER_PORT The base port number for the first worker. If set, -# subsequent workers will increment this number. If -# unset, Spark will find a valid port number, but -# with no guarantee of a predictable pattern. -# SPARK_WORKER_WEBUI_PORT The base port for the web interface of the first -# worker. Subsequent workers will increment this -# number. Default is 8081. - -if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -fi - -# NOTE: This exact class name is matched downstream by SparkSubmit. -# Any changes need to be reflected there. -CLASS="org.apache.spark.deploy.worker.Worker" - -if [[ $# -lt 1 ]] || [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./sbin/start-slave.sh [options]" - pattern="Usage:" - pattern+="\|Using Spark's default log4j profile:" - pattern+="\|Started daemon with process name" - pattern+="\|Registered signal handler for" - - "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 - exit 1 -fi - -. "${SPARK_HOME}/sbin/spark-config.sh" - -. "${SPARK_HOME}/bin/load-spark-env.sh" - -# First argument should be the master; we need to store it aside because we may -# need to insert arguments between it and the other arguments -MASTER=$1 -shift - -# Determine desired worker port -if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then - SPARK_WORKER_WEBUI_PORT=8081 -fi - -# Start up the appropriate number of workers on this machine. -# quick local function to start a worker -function start_instance { - WORKER_NUM=$1 - shift - - if [ "$SPARK_WORKER_PORT" = "" ]; then - PORT_FLAG= - PORT_NUM= - else - PORT_FLAG="--port" - PORT_NUM=$(( $SPARK_WORKER_PORT + $WORKER_NUM - 1 )) - fi - WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) - - "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS $WORKER_NUM \ - --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" -} +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - start_instance 1 "$@" -else - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - start_instance $(( 1 + $i )) "$@" - done -fi +>&2 echo "This script is deprecated, use start-worker.sh" +"${DIR}/start-worker.sh" "$@" diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index f5269df523dac..9b113d9f2e0f4 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -17,30 +17,7 @@ # limitations under the License. # -# Starts a slave instance on each machine specified in the conf/slaves file. +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -fi - -. "${SPARK_HOME}/sbin/spark-config.sh" -. 
"${SPARK_HOME}/bin/load-spark-env.sh" - -# Find the port number for the master -if [ "$SPARK_MASTER_PORT" = "" ]; then - SPARK_MASTER_PORT=7077 -fi - -if [ "$SPARK_MASTER_HOST" = "" ]; then - case `uname` in - (SunOS) - SPARK_MASTER_HOST="`/usr/sbin/check-hostname | awk '{print $NF}'`" - ;; - (*) - SPARK_MASTER_HOST="`hostname -f`" - ;; - esac -fi - -# Launch the slaves -"${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/start-slave.sh" "spark://$SPARK_MASTER_HOST:$SPARK_MASTER_PORT" +>&2 echo "This script is deprecated, use start-workers.sh" +"${DIR}/start-workers.sh" "$@" diff --git a/sbin/start-worker.sh b/sbin/start-worker.sh new file mode 100755 index 0000000000000..fd58f01bac2eb --- /dev/null +++ b/sbin/start-worker.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts a worker on the machine this script is executed on. +# +# Environment Variables +# +# SPARK_WORKER_INSTANCES The number of worker instances to run on this +# worker. Default is 1. Note it has been deprecate since Spark 3.0. +# SPARK_WORKER_PORT The base port number for the first worker. If set, +# subsequent workers will increment this number. If +# unset, Spark will find a valid port number, but +# with no guarantee of a predictable pattern. +# SPARK_WORKER_WEBUI_PORT The base port for the web interface of the first +# worker. Subsequent workers will increment this +# number. Default is 8081. + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.worker.Worker" + +if [[ $# -lt 1 ]] || [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-worker.sh [options]" + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Started daemon with process name" + pattern+="\|Registered signal handler for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" + +. "${SPARK_HOME}/bin/load-spark-env.sh" + +# First argument should be the master; we need to store it aside because we may +# need to insert arguments between it and the other arguments +MASTER=$1 +shift + +# Determine desired worker port +if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then + SPARK_WORKER_WEBUI_PORT=8081 +fi + +# Start up the appropriate number of workers on this machine. 
+# quick local function to start a worker +function start_instance { + WORKER_NUM=$1 + shift + + if [ "$SPARK_WORKER_PORT" = "" ]; then + PORT_FLAG= + PORT_NUM= + else + PORT_FLAG="--port" + PORT_NUM=$(( $SPARK_WORKER_PORT + $WORKER_NUM - 1 )) + fi + WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) + + "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS $WORKER_NUM \ + --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" +} + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + start_instance 1 "$@" +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + start_instance $(( 1 + $i )) "$@" + done +fi diff --git a/sbin/start-workers.sh b/sbin/start-workers.sh new file mode 100755 index 0000000000000..3867ef3ccf255 --- /dev/null +++ b/sbin/start-workers.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts a worker instance on each machine specified in the conf/workers file. + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" + +# Find the port number for the master +if [ "$SPARK_MASTER_PORT" = "" ]; then + SPARK_MASTER_PORT=7077 +fi + +if [ "$SPARK_MASTER_HOST" = "" ]; then + case `uname` in + (SunOS) + SPARK_MASTER_HOST="`/usr/sbin/check-hostname | awk '{print $NF}'`" + ;; + (*) + SPARK_MASTER_HOST="`hostname -f`" + ;; + esac +fi + +# Launch the workers +"${SPARK_HOME}/sbin/workers.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/start-worker.sh" "spark://$SPARK_MASTER_HOST:$SPARK_MASTER_PORT" diff --git a/sbin/stop-all.sh b/sbin/stop-all.sh index 4e476ca05cb05..2c40905cd499b 100755 --- a/sbin/stop-all.sh +++ b/sbin/stop-all.sh @@ -27,8 +27,8 @@ fi # Load the Spark configuration . "${SPARK_HOME}/sbin/spark-config.sh" -# Stop the slaves, then the master -"${SPARK_HOME}/sbin"/stop-slaves.sh +# Stop the workers, then the master +"${SPARK_HOME}/sbin"/stop-workers.sh "${SPARK_HOME}/sbin"/stop-master.sh if [ "$1" == "--wait" ] @@ -36,7 +36,7 @@ then printf "Waiting for workers to shut down..." while true do - running=`${SPARK_HOME}/sbin/slaves.sh ps -ef | grep -v grep | grep deploy.worker.Worker` + running=`${SPARK_HOME}/sbin/workers.sh ps -ef | grep -v grep | grep deploy.worker.Worker` if [ -z "$running" ] then printf "\nAll workers successfully shut down.\n" diff --git a/sbin/stop-slave.sh b/sbin/stop-slave.sh index 685bcf59b33aa..71ed29987d4a1 100755 --- a/sbin/stop-slave.sh +++ b/sbin/stop-slave.sh @@ -17,28 +17,7 @@ # limitations under the License. # -# A shell script to stop all workers on a single slave -# -# Environment variables -# -# SPARK_WORKER_INSTANCES The number of worker instances that should be -# running on this slave. Default is 1. 
- -# Usage: stop-slave.sh -# Stops all slaves on this worker machine - -if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -fi - -. "${SPARK_HOME}/sbin/spark-config.sh" - -. "${SPARK_HOME}/bin/load-spark-env.sh" +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 -else - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) - done -fi +>&2 echo "This script is deprecated, use stop-worker.sh" +"${DIR}/stop-worker.sh" "$@" diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index a57441b52a04a..c0aca6868efe3 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -17,12 +17,7 @@ # limitations under the License. # -if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -fi +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -. "${SPARK_HOME}/sbin/spark-config.sh" - -. "${SPARK_HOME}/bin/load-spark-env.sh" - -"${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/stop-slave.sh +>&2 echo "This script is deprecated, use stop-workers.sh" +"${DIR}/stop-workers.sh" "$@" diff --git a/sbin/stop-worker.sh b/sbin/stop-worker.sh new file mode 100755 index 0000000000000..112b62ecffa27 --- /dev/null +++ b/sbin/stop-worker.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A shell script to stop all workers on a single worker +# +# Environment variables +# +# SPARK_WORKER_INSTANCES The number of worker instances that should be +# running on this worker machine. Default is 1. + +# Usage: stop-worker.sh +# Stops all workers on this worker machine + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" + +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + done +fi diff --git a/sbin/stop-workers.sh b/sbin/stop-workers.sh new file mode 100755 index 0000000000000..552800f522222 --- /dev/null +++ b/sbin/stop-workers.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" + +. "${SPARK_HOME}/bin/load-spark-env.sh" + +"${SPARK_HOME}/sbin/workers.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/stop-worker.sh diff --git a/sbin/workers.sh b/sbin/workers.sh new file mode 100755 index 0000000000000..cab0330723a6c --- /dev/null +++ b/sbin/workers.sh @@ -0,0 +1,120 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Run a shell command on all worker hosts. +# +# Environment Variables +# +# SPARK_WORKERS File naming remote hosts. +# Default is ${SPARK_CONF_DIR}/workers. +# SPARK_CONF_DIR Alternate conf dir. Default is ${SPARK_HOME}/conf. +# SPARK_WORKER_SLEEP Seconds to sleep between spawning remote commands. +# SPARK_SSH_OPTS Options passed to ssh when running remote commands. +## + +usage="Usage: workers.sh [--config ] command..." + +# if no args specified, show usage +if [ $# -le 0 ]; then + echo $usage + exit 1 +fi + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" + +# If the workers file is specified in the command line, +# then it takes precedence over the definition in +# spark-env.sh. Save it here. +if [ -f "$SPARK_WORKERS" ]; then + HOSTLIST=`cat "$SPARK_WORKERS"` +fi +if [ -f "$SPARK_SLAVES" ]; then + >&2 echo "SPARK_SLAVES is deprecated, use SPARK_WORKERS" + HOSTLIST=`cat "$SPARK_SLAVES"` +fi + + +# Check if --config is passed as an argument. It is an optional parameter. +# Exit if the argument is not a directory. +if [ "$1" == "--config" ] +then + shift + conf_dir="$1" + if [ ! -d "$conf_dir" ] + then + echo "ERROR : $conf_dir is not a directory" + echo $usage + exit 1 + else + export SPARK_CONF_DIR="$conf_dir" + fi + shift +fi + +. 
"${SPARK_HOME}/bin/load-spark-env.sh" + +if [ "$HOSTLIST" = "" ]; then + if [ "$SPARK_SLAVES" = "" ] && [ "$SPARK_WORKERS" = "" ]; then + if [ -f "${SPARK_CONF_DIR}/workers" ]; then + HOSTLIST=`cat "${SPARK_CONF_DIR}/workers"` + elif [ -f "${SPARK_CONF_DIR}/slaves" ]; then + HOSTLIST=`cat "${SPARK_CONF_DIR}/slaves"` + else + HOSTLIST=localhost + fi + else + if [ -f "$SPARK_WORKERS" ]; then + HOSTLIST=`cat "$SPARK_WORKERS"` + fi + if [ -f "$SPARK_SLAVES" ]; then + >&2 echo "SPARK_SLAVES is deprecated, use SPARK_WORKERS" + HOSTLIST=`cat "$SPARK_SLAVES"` + fi + fi +fi + + + +# By default disable strict host key checking +if [ "$SPARK_SSH_OPTS" = "" ]; then + SPARK_SSH_OPTS="-o StrictHostKeyChecking=no" +fi + +for host in `echo "$HOSTLIST"|sed "s/#.*$//;/^$/d"`; do + if [ -n "${SPARK_SSH_FOREGROUND}" ]; then + ssh $SPARK_SSH_OPTS "$host" $"${@// /\\ }" \ + 2>&1 | sed "s/^/$host: /" + else + ssh $SPARK_SSH_OPTS "$host" $"${@// /\\ }" \ + 2>&1 | sed "s/^/$host: /" & + fi + if [ "$SPARK_WORKER_SLEEP" != "" ]; then + sleep $SPARK_WORKER_SLEEP + fi + if [ "$SPARK_SLAVE_SLEEP" != "" ]; then + >&2 echo "SPARK_SLAVE_SLEEP is deprecated, use SPARK_WORKER_SLEEP" + sleep $SPARK_SLAVE_SLEEP + fi +done + +wait diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 691fde8d48f94..d29fa1319daa5 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -988,6 +988,7 @@ number | MINUS? SMALLINT_LITERAL #smallIntLiteral | MINUS? TINYINT_LITERAL #tinyIntLiteral | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; @@ -1461,8 +1462,7 @@ nonReserved ; // NOTE: If you add a new token in the list below, you should update the list of keywords -// in `docs/sql-keywords.md`. If the token is a non-reserved keyword, -// please update `ansiNonReserved` and `nonReserved` as well. +// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`. //============================ // Start of the keywords list @@ -1532,6 +1532,7 @@ DIRECTORIES: 'DIRECTORIES'; DIRECTORY: 'DIRECTORY'; DISTINCT: 'DISTINCT'; DISTRIBUTE: 'DISTRIBUTE'; +DIV: 'DIV'; DROP: 'DROP'; ELSE: 'ELSE'; END: 'END'; @@ -1739,7 +1740,6 @@ MINUS: '-'; ASTERISK: '*'; SLASH: '/'; PERCENT: '%'; -DIV: 'DIV'; TILDE: '~'; AMPERSAND: '&'; PIPE: '|'; @@ -1776,6 +1776,11 @@ DECIMAL_VALUE : DECIMAL_DIGITS {isValidDecimal()}? ; +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + DOUBLE_LITERAL : DIGIT+ EXPONENT? 'D' | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 034894bd86085..4dc5ce1de047b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -288,7 +288,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) { Platform.putLong(baseObject, baseOffset + cursor, 0L); Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); - if (value == null) { + if (value == null || !value.changePrecision(precision, value.scale())) { setNullAt(ordinal); // keep the offset for future update Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d08a6382f738b..023ef2ee17473 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -200,6 +200,8 @@ class Analyzer( val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil lazy val batches: Seq[Batch] = Seq( + Batch("Disable Hints", Once, + new ResolveHints.DisableHints(conf)), Batch("Hints", fixedPoint, new ResolveHints.ResolveJoinStrategyHints(conf), new ResolveHints.ResolveCoalesceHints(conf)), @@ -1048,14 +1050,12 @@ class Analyzer( val partCols = partitionColumnNames(r.table) validatePartitionSpec(partCols, i.partitionSpec) - val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get) + val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get).toMap val query = addStaticPartitionColumns(r, i.query, staticPartitions) - val dynamicPartitionOverwrite = partCols.size > staticPartitions.size && - conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC if (!i.overwrite) { AppendData.byPosition(r, query) - } else if (dynamicPartitionOverwrite) { + } else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) { OverwritePartitionsDynamic.byPosition(r, query) } else { OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions)) @@ -2238,7 +2238,7 @@ class Analyzer( } } if (aggregateExpressions.nonEmpty) { - Some(aggregateExpressions, transformedAggregateFilter) + Some(aggregateExpressions.toSeq, transformedAggregateFilter) } else { None } @@ -2677,7 +2677,7 @@ class Analyzer( val windowOps = groupedWindowExpressions.foldLeft(child) { case (last, ((partitionSpec, orderSpec, _), windowExpressions)) => - Window(windowExpressions, partitionSpec, orderSpec, last) + Window(windowExpressions.toSeq, partitionSpec, orderSpec, last) } // Finally, we create a Project to output windowOps's output @@ -2819,13 +2819,12 @@ class Analyzer( case p => p transformExpressionsUp { - case udf @ ScalaUDF(_, _, inputs, _, _, _, _) - if udf.inputPrimitives.contains(true) => + case udf: ScalaUDF if udf.inputPrimitives.contains(true) => // Otherwise, add special handling of null for fields that can't accept null. // The result of operations like this, when passed null, is generally to return null. 
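The Analyzer rule touched below wraps primitive-typed ScalaUDF inputs in null checks because a null reaching a primitive parameter is silently unboxed to a default value and produces a wrong non-null result instead of propagating null. The REPL-pasteable snippet below illustrates the failure mode and the conceptual guard; it stands in for the catalyst expression rewrite rather than reproducing it.

```scala
// addOne stands in for a user-defined function over a primitive parameter.
val addOne: Int => Int = _ + 1

// Without a guard, a null input is unboxed to 0 and yields a wrong non-null result.
assert(addOne(null.asInstanceOf[Int]) == 1)

// With the guard the analyzer injects (conceptually If(IsNull(input), null, udf(input))),
// null inputs short-circuit to null instead.
def callWithNullGuard(arg: Any): Any =
  if (arg == null) null else addOne(arg.asInstanceOf[Int])

assert(callWithNullGuard(41) == 42)
assert(callWithNullGuard(null) == null)
```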
- assert(udf.inputPrimitives.length == inputs.length) + assert(udf.inputPrimitives.length == udf.children.length) - val inputPrimitivesPair = udf.inputPrimitives.zip(inputs) + val inputPrimitivesPair = udf.inputPrimitives.zip(udf.children) val inputNullCheck = inputPrimitivesPair.collect { case (isPrimitive, input) if isPrimitive && input.nullable => IsNull(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 9f0eff5017f38..623cd131bf8da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -72,10 +72,10 @@ object CTESubstitution extends Rule[LogicalPlan] { } // CTE relation is defined as `SubqueryAlias`. Here we skip it and check the child // directly, so that `startOfQuery` is set correctly. - assertNoNameConflictsInCTE(relation.child, newNames) + assertNoNameConflictsInCTE(relation.child, newNames.toSeq) newNames += name } - assertNoNameConflictsInCTE(child, newNames, startOfQuery = false) + assertNoNameConflictsInCTE(child, newNames.toSeq, startOfQuery = false) case other => other.subqueries.foreach(assertNoNameConflictsInCTE(_, outerCTERelationNames)) @@ -162,9 +162,9 @@ object CTESubstitution extends Rule[LogicalPlan] { traverseAndSubstituteCTE(relation) } // CTE definition can reference a previous one - resolvedCTERelations += (name -> substituteCTE(innerCTEResolved, resolvedCTERelations)) + resolvedCTERelations += (name -> substituteCTE(innerCTEResolved, resolvedCTERelations.toSeq)) } - resolvedCTERelations + resolvedCTERelations.toSeq } private def substituteCTE( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9c99acaa994b8..351be32ee438e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -158,6 +158,11 @@ trait CheckAnalysis extends PredicateHelper { case g: GroupingID => failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup") + case e: Expression if e.children.exists(_.isInstanceOf[WindowFunction]) && + !e.isInstanceOf[WindowExpression] => + val w = e.children.find(_.isInstanceOf[WindowFunction]).get + failAnalysis(s"Window function $w requires an OVER clause.") + case w @ WindowExpression(AggregateExpression(_, _, true, _, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") @@ -862,7 +867,7 @@ trait CheckAnalysis extends PredicateHelper { // Simplify the predicates before validating any unsupported correlation patterns in the plan. AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp { - // Whitelist operators allowed in a correlated subquery + // Approve operators allowed in a correlated subquery // There are 4 categories: // 1. 
Operators that are allowed anywhere in a correlated subquery, and, // by definition of the operators, they either do not contain diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 2a0a944e4849c..a40604045978c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -34,6 +34,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case AlterTableAddColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => + cols.foreach(c => failNullType(c.dataType)) cols.foreach(c => failCharType(c.dataType)) val changes = cols.map { col => TableChange.addColumn( @@ -47,6 +48,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case AlterTableReplaceColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => + cols.foreach(c => failNullType(c.dataType)) cols.foreach(c => failCharType(c.dataType)) val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match { case Some(table) => @@ -69,6 +71,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case a @ AlterTableAlterColumnStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _) => + a.dataType.foreach(failNullType) a.dataType.foreach(failCharType) val colName = a.column.toArray val typeChange = a.dataType.map { newDataType => @@ -145,6 +148,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ CreateTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + assertNoNullTypeInSchema(c.tableSchema) assertNoCharTypeInSchema(c.tableSchema) CreateV2Table( catalog.asTableCatalog, @@ -157,6 +161,9 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ CreateTableAsSelectStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + if (c.asSelect.resolved) { + assertNoNullTypeInSchema(c.asSelect.schema) + } CreateTableAsSelect( catalog.asTableCatalog, tbl.asIdentifier, @@ -172,6 +179,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ ReplaceTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + assertNoNullTypeInSchema(c.tableSchema) assertNoCharTypeInSchema(c.tableSchema) ReplaceTable( catalog.asTableCatalog, @@ -184,6 +192,9 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ ReplaceTableAsSelectStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + if (c.asSelect.resolved) { + assertNoNullTypeInSchema(c.asSelect.schema) + } ReplaceTableAsSelect( catalog.asTableCatalog, tbl.asIdentifier, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 4cbff62e16cc1..120842b0c4a07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -278,4 +278,15 @@ object ResolveHints { h.child } } + + /** + * Removes all the hints when `spark.sql.optimizer.disableHints` is set. 
+ * This is executed at the very beginning of the Analyzer to disable + * the hint functionality. + */ + class DisableHints(conf: SQLConf) extends RemoveAllHints(conf: SQLConf) { + override def apply(plan: LogicalPlan): LogicalPlan = { + if (conf.getConf(SQLConf.DISABLE_HINTS)) super.apply(plan) else plan + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 423f89fefa093..0c11830cf06dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -393,17 +393,6 @@ object UnsupportedOperationChecker extends Logging { _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => - case Repartition(1, false, _) => - case node: Aggregate => - val aboveSinglePartitionCoalesce = node.find { - case Repartition(1, false, _) => true - case _ => false - }.isDefined - - if (!aboveSinglePartitionCoalesce) { - throwError(s"In continuous processing mode, coalesce(1) must be called before " + - s"aggregate operation ${node.nodeName}.") - } case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index f08416fcaba8a..3d5c1855f6975 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -24,6 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.ScalaReflection.Schema import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer} import org.apache.spark.sql.catalyst.expressions._ @@ -32,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Initial import org.apache.spark.sql.catalyst.optimizer.{ReassignLambdaVariableID, SimplifyCasts} import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LeafNode, LocalRelation} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, ObjectType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -305,6 +306,11 @@ case class ExpressionEncoder[T]( StructField(s.name, s.dataType, s.nullable) }) + def dataTypeAndNullable: Schema = { + val dataType = if (isSerializedAsStruct) schema else schema.head.dataType + Schema(dataType, objSerializer.nullable) + } + /** * Returns true if the type `T` is serialized as a struct by `objSerializer`. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index a32052ce121df..458c48df6d0c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -106,7 +106,7 @@ class EquivalentExpressions { * an empty collection if there are none. */ def getEquivalentExprs(e: Expression): Seq[Expression] = { - equivalenceMap.getOrElse(Expr(e), Seq.empty) + equivalenceMap.getOrElse(Expr(e), Seq.empty).toSeq } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3d10b084a8db1..6e2bd96784b94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, UserDefinedType} +import org.apache.spark.util.Utils /** * User-defined function. @@ -36,6 +37,8 @@ import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, User * @param inputEncoders ExpressionEncoder for each input parameters. For a input parameter which * serialized as struct will use encoder instead of CatalystTypeConverters to * convert internal value to Scala value. + * @param outputEncoder ExpressionEncoder for the return type of function. It's only defined when + * this is a typed Scala UDF. * @param udfName The user-specified name of this UDF. * @param nullable True if the UDF can return null value. * @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result @@ -46,6 +49,7 @@ case class ScalaUDF( dataType: DataType, children: Seq[Expression], inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil, + outputEncoder: Option[ExpressionEncoder[_]] = None, udfName: Option[String] = None, nullable: Boolean = true, udfDeterministic: Boolean = true) @@ -55,6 +59,12 @@ case class ScalaUDF( override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})" + override lazy val canonicalized: Expression = { + // SPARK-32307: `ExpressionEncoder` can't be canonicalized, and technically we don't + // need it to identify a `ScalaUDF`. + Canonicalize.execute(copy(children = children.map(_.canonicalized), inputEncoders = Nil)) + } + /** * The analyzer should be aware of Scala primitive types so as to make the * UDF return null if there is any null input value of these types. On the @@ -62,7 +72,7 @@ case class ScalaUDF( * Nil(has same effect with all false) and analyzer will skip null-handling * on them. */ - def inputPrimitives: Seq[Boolean] = { + lazy val inputPrimitives: Seq[Boolean] = { inputEncoders.map { encoderOpt => // It's possible that some of the inputs don't have a specific encoder(e.g. `Any`) if (encoderOpt.isDefined) { @@ -102,6 +112,23 @@ case class ScalaUDF( } } + /** + * Create the converter which converts the scala data type to the catalyst data type for + * the return data type of udf function. 
We'd use `ExpressionEncoder` to create the + * converter for typed ScalaUDF only, since its the only case where we know the type tag + * of the return data type of udf function. + */ + private def catalystConverter: Any => Any = outputEncoder.map { enc => + val toRow = enc.createSerializer().asInstanceOf[Any => Any] + if (enc.isSerializedAsStruct) { + value: Any => + if (value == null) null else toRow(value).asInstanceOf[InternalRow] + } else { + value: Any => + if (value == null) null else toRow(value).asInstanceOf[InternalRow].get(0, dataType) + } + }.getOrElse(createToCatalystConverter(dataType)) + /** * Create the converter which converts the catalyst data type to the scala data type. * We use `CatalystTypeConverters` to create the converter for: @@ -1071,7 +1098,7 @@ case class ScalaUDF( val (converters, useEncoders): (Array[Any => Any], Array[Boolean]) = (children.zipWithIndex.map { case (c, i) => scalaConverter(i, c.dataType) - }.toArray :+ (createToCatalystConverter(dataType), false)).unzip + }.toArray :+ (catalystConverter, false)).unzip val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]") val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage) val resultTerm = ctx.freshName("result") @@ -1149,10 +1176,10 @@ case class ScalaUDF( """.stripMargin) } - private[this] val resultConverter = createToCatalystConverter(dataType) + private[this] val resultConverter = catalystConverter lazy val udfErrorMessage = { - val funcCls = function.getClass.getSimpleName + val funcCls = Utils.getSimpleName(function.getClass) val inputTypes = children.map(_.dataType.catalogString).mkString(", ") val outputType = dataType.catalogString s"Failed to execute user defined function($funcCls: ($inputTypes) => $outputType)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 6e850267100fb..a29ae2c8b65a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -58,13 +58,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => DoubleType } - private lazy val sumDataType = resultType - - private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val sum = AttributeReference("sum", resultType)() private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)() - private lazy val zero = Literal.default(sumDataType) + private lazy val zero = Literal.default(resultType) override lazy val aggBufferAttributes = resultType match { case _: DecimalType => sum :: isEmpty :: Nil @@ -72,25 +70,38 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast } override lazy val initialValues: Seq[Expression] = resultType match { - case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType)) + case _: DecimalType => Seq(zero, Literal(true, BooleanType)) case _ => Seq(Literal(null, resultType)) } override lazy val updateExpressions: Seq[Expression] = { - if (child.nullable) { - val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - resultType match { - case _: DecimalType => - Seq(updateSumExpr, isEmpty && child.isNull) - case _ => Seq(updateSumExpr) - } - } else { - val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType) - 
resultType match { - case _: DecimalType => - Seq(updateSumExpr, Literal(false, BooleanType)) - case _ => Seq(updateSumExpr) - } + resultType match { + case _: DecimalType => + // For decimal type, the initial value of `sum` is 0. We need to keep `sum` unchanged if + // the input is null, as SUM function ignores null input. The `sum` can only be null if + // overflow happens under non-ansi mode. + val sumExpr = if (child.nullable) { + If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType)) + } else { + sum + child.cast(resultType) + } + // The buffer becomes non-empty after seeing the first not-null input. + val isEmptyExpr = if (child.nullable) { + isEmpty && child.isNull + } else { + Literal(false, BooleanType) + } + Seq(sumExpr, isEmptyExpr) + case _ => + // For non-decimal type, the initial value of `sum` is null, which indicates no value. + // We need `coalesce(sum, zero)` to start summing values. And we need an outer `coalesce` + // in case the input is nullable. The `sum` can only be null if there is no value, as + // non-decimal type can produce overflowed value under non-ansi mode. + if (child.nullable) { + Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum)) + } else { + Seq(coalesce(sum, zero) + child.cast(resultType)) + } } } @@ -107,15 +118,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast * means we have seen atleast a value that was not null. */ override lazy val mergeExpressions: Seq[Expression] = { - val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left) resultType match { case _: DecimalType => - val inputOverflow = !isEmpty.right && sum.right.isNull val bufferOverflow = !isEmpty.left && sum.left.isNull + val inputOverflow = !isEmpty.right && sum.right.isNull Seq( - If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr), + If( + bufferOverflow || inputOverflow, + Literal.create(null, resultType), + // If both the buffer and the input do not overflow, just add them, as they can't be + // null. See the comments inside `updateExpressions`: `sum` can only be null if + // overflow happens. + KnownNotNull(sum.left) + KnownNotNull(sum.right)), isEmpty.left && isEmpty.right) - case _ => Seq(mergeSumExpr) + case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left)) } } @@ -128,7 +144,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast */ override lazy val evaluateExpression: Expression = resultType match { case d: DecimalType => - If(isEmpty, Literal.create(null, sumDataType), + If(isEmpty, Literal.create(null, resultType), CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) case _ => sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 817dd948f1a6a..9c20916790c21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -375,7 +375,7 @@ class CodegenContext extends Logging { // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. 
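[Illustrative sketch, not part of the patch] The Sum(DecimalType) hunk above replaces the coalesce-based update with an explicit (sum, isEmpty) buffer: sum starts at 0, isEmpty flips to false on the first non-null input, and a null sum can only mean overflow under non-ANSI mode. A minimal model of those semantics in plain Scala, with an overflowed sum represented as None (all names here are illustrative):

    final case class SumBuf(sum: Option[BigDecimal], isEmpty: Boolean)

    // Initial buffer: sum is 0 (not null), isEmpty is true until a non-null input is seen.
    val initial = SumBuf(Some(BigDecimal(0)), isEmpty = true)

    // Update: null inputs leave the buffer unchanged; a non-null input clears isEmpty.
    def update(buf: SumBuf, input: Option[BigDecimal]): SumBuf = input match {
      case None    => buf
      case Some(v) => SumBuf(buf.sum.map(_ + v), isEmpty = false) // None stays None once overflowed
    }

    // Merge: a non-empty side with a null sum means that side overflowed.
    def merge(l: SumBuf, r: SumBuf): SumBuf = {
      val overflow = (!l.isEmpty && l.sum.isEmpty) || (!r.isEmpty && r.sum.isEmpty)
      if (overflow) SumBuf(None, isEmpty = false)
      else SumBuf(for (a <- l.sum; b <- r.sum) yield a + b, l.isEmpty && r.isEmpty)
    }

    // Evaluate: no input at all yields null; otherwise the (possibly overflowed) sum.
    def evaluate(buf: SumBuf): Option[BigDecimal] = if (buf.isEmpty) None else buf.sum

BigDecimal itself does not overflow; in the real expression the nulling happens through CheckOverflow/CheckOverflowInSum when ANSI mode is off.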
- splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil) + splitExpressions(expressions = initCodes.toSeq, funcName = "init", arguments = Nil) } /** @@ -927,6 +927,7 @@ class CodegenContext extends Logging { length += CodeFormatter.stripExtraNewLinesAndComments(code).length } blocks += blockBuilder.toString() + blocks.toSeq } /** @@ -1002,7 +1003,7 @@ class CodegenContext extends Logging { def subexprFunctionsCode: String = { // Whole-stage codegen's subexpression elimination is handled in another code path assert(currentVars == null || subexprFunctions.isEmpty) - splitExpressions(subexprFunctions, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) + splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 026a2a677baec..74c9b12a109d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2612,10 +2612,17 @@ object Sequence { val stepDays = step.days val stepMicros = step.microseconds + if (scale == MICROS_PER_DAY && stepMonths == 0 && stepDays == 0) { + throw new IllegalArgumentException( + "sequence step must be a day interval if start and end values are dates") + } + if (stepMonths == 0 && stepMicros == 0 && scale == MICROS_PER_DAY) { + // Adding pure days to date start/end backedSequenceImpl.eval(start, stop, fromLong(stepDays)) } else if (stepMonths == 0 && stepDays == 0 && scale == 1) { + // Adding pure microseconds to timestamp start/end backedSequenceImpl.eval(start, stop, fromLong(stepMicros)) } else { @@ -2674,11 +2681,24 @@ object Sequence { |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)} """.stripMargin + val check = if (scale == MICROS_PER_DAY) { + s""" + |if ($stepMonths == 0 && $stepDays == 0) { + | throw new IllegalArgumentException( + | "sequence step must be a day interval if start and end values are dates"); + |} + """.stripMargin + } else { + "" + } + s""" |final int $stepMonths = $step.months; |final int $stepDays = $step.days; |final long $stepMicros = $step.microseconds; | + |$check + | |if ($stepMonths == 0 && $stepMicros == 0 && ${scale}L == ${MICROS_PER_DAY}L) { | ${backedSequenceImpl.genCode(ctx, start, stop, stepDays, arr, elemType)}; | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1b4a705e804f1..cf7cc3a5e16ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -539,3 +539,61 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E override def prettyName: String = "str_to_map" } + +/** + * Adds/replaces field in struct by name. 
+ */ +case class WithFields( + structExpr: Expression, + names: Seq[String], + valExprs: Seq[Expression]) extends Unevaluable { + + assert(names.length == valExprs.length) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!structExpr.dataType.isInstanceOf[StructType]) { + TypeCheckResult.TypeCheckFailure( + "struct argument should be struct type, got: " + structExpr.dataType.catalogString) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def children: Seq[Expression] = structExpr +: valExprs + + override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType] + + override def foldable: Boolean = structExpr.foldable && valExprs.forall(_.foldable) + + override def nullable: Boolean = structExpr.nullable + + override def prettyName: String = "with_fields" + + lazy val evalExpr: Expression = { + val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { + case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression]) + } + + val addOrReplaceExprs = names.zip(valExprs) + + val resolver = SQLConf.get.resolver + val newExprs = addOrReplaceExprs.foldLeft(existingExprs) { + case (resultExprs, newExpr @ (newExprName, _)) => + if (resultExprs.exists(x => resolver(x._1, newExprName))) { + resultExprs.map { + case (name, _) if resolver(name, newExprName) => newExpr + case x => x + } + } else { + resultExprs :+ newExpr + } + }.flatMap { case (name, expr) => Seq(Literal(name), expr) } + + val expr = CreateNamedStruct(newExprs) + if (structExpr.nullable) { + If(IsNull(structExpr), Literal(null, expr.dataType), expr) + } else { + expr + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index a1277217b1b3a..3d9612018aaf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -686,13 +686,13 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(timeExp[, format]) - Returns the UNIX timestamp of the given time.", + usage = "_FUNC_(timeExp[, fmt]) - Returns the UNIX timestamp of the given time.", arguments = """ Arguments: * timeExp - A date/timestamp or string which is returned as a UNIX timestamp. - * format - Date/time format pattern to follow. Ignored if `timeExp` is not a string. - Default value is "yyyy-MM-dd HH:mm:ss". See Datetime Patterns - for valid date and time format patterns. + * fmt - Date/time format pattern to follow. Ignored if `timeExp` is not a string. + Default value is "yyyy-MM-dd HH:mm:ss". See Datetime Patterns + for valid date and time format patterns. """, examples = """ Examples: @@ -734,13 +734,13 @@ case class ToUnixTimestamp( * second parameter. */ @ExpressionDescription( - usage = "_FUNC_([timeExp[, format]]) - Returns the UNIX timestamp of current or specified time.", + usage = "_FUNC_([timeExp[, fmt]]) - Returns the UNIX timestamp of current or specified time.", arguments = """ Arguments: * timeExp - A date/timestamp or string. If not provided, this defaults to current time. - * format - Date/time format pattern to follow. Ignored if `timeExp` is not a string. - Default value is "yyyy-MM-dd HH:mm:ss". 
See Datetime Patterns - for valid date and time format patterns. + * fmt - Date/time format pattern to follow. Ignored if `timeExp` is not a string. + Default value is "yyyy-MM-dd HH:mm:ss". See Datetime Patterns + for valid date and time format patterns. """, examples = """ Examples: @@ -891,16 +891,16 @@ abstract class UnixTime extends ToTimestamp { * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string * representing the timestamp of that moment in the current system time zone in the given * format. If the format is missing, using format like "1970-01-01 00:00:00". - * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. + * Note that Hive Language Manual says it returns 0 if fail, but in fact it returns null. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(unix_time, format) - Returns `unix_time` in the specified `format`.", + usage = "_FUNC_(unix_time, fmt) - Returns `unix_time` in the specified `fmt`.", arguments = """ Arguments: * unix_time - UNIX Timestamp to be converted to the provided format. - * format - Date/time format pattern to follow. See Datetime Patterns - for valid date and time format patterns. + * fmt - Date/time format pattern to follow. See Datetime Patterns + for valid date and time format patterns. """, examples = """ Examples: @@ -1176,41 +1176,15 @@ case class DateAddInterval( copy(timeZoneId = Option(timeZoneId)) } -/** - * This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function - * takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and - * renders that timestamp as a timestamp in the given time zone. - * - * However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not - * timezone-agnostic. So in Spark this function just shift the timestamp value from UTC timezone to - * the given timezone. - * - * This function may return confusing result if the input is a string with timezone, e.g. - * '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp - * according to the timezone in the string, and finally display the result by converting the - * timestamp to string according to the session local timezone. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(timestamp, timezone) - Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders that time as a timestamp in the given time zone. 
For example, 'GMT+1' would yield '2017-07-14 03:40:00.0'.", - examples = """ - Examples: - > SELECT _FUNC_('2016-08-31', 'Asia/Seoul'); - 2016-08-31 09:00:00 - """, - group = "datetime_funcs", - since = "1.5.0") -// scalastyle:on line.size.limit -case class FromUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { +sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + val func: (Long, String) => Long + val funcName: String override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override def dataType: DataType = TimestampType - override def prettyName: String = "from_utc_timestamp" override def nullSafeEval(time: Any, timezone: Any): Any = { - DateTimeUtils.fromUTCTime(time.asInstanceOf[Long], - timezone.asInstanceOf[UTF8String].toString) + func(time.asInstanceOf[Long], timezone.asInstanceOf[UTF8String].toString) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -1229,24 +1203,90 @@ case class FromUTCTimestamp(left: Expression, right: Expression) val tzTerm = ctx.addMutableState(tzClass, "tz", v => s"""$v = $dtu.getZoneId("$escapedTz");""") val utcTerm = "java.time.ZoneOffset.UTC" + val (fromTz, toTz) = this match { + case _: FromUTCTimestamp => (utcTerm, tzTerm) + case _: ToUTCTimestamp => (tzTerm, utcTerm) + } val eval = left.genCode(ctx) ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; |if (!${ev.isNull}) { - | ${ev.value} = $dtu.convertTz(${eval.value}, $utcTerm, $tzTerm); + | ${ev.value} = $dtu.convertTz(${eval.value}, $fromTz, $toTz); |} """.stripMargin) } } else { defineCodeGen(ctx, ev, (timestamp, format) => { - s"""$dtu.fromUTCTime($timestamp, $format.toString())""" + s"""$dtu.$funcName($timestamp, $format.toString())""" }) } } } +/** + * This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function + * takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and + * renders that timestamp as a timestamp in the given time zone. + * + * However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not + * timezone-agnostic. So in Spark this function just shift the timestamp value from UTC timezone to + * the given timezone. + * + * This function may return confusing result if the input is a string with timezone, e.g. + * '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp + * according to the timezone in the string, and finally display the result by converting the + * timestamp to string according to the session local timezone. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, timezone) - Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders that time as a timestamp in the given time zone. 
For example, 'GMT+1' would yield '2017-07-14 03:40:00.0'.", + examples = """ + Examples: + > SELECT _FUNC_('2016-08-31', 'Asia/Seoul'); + 2016-08-31 09:00:00 + """, + group = "datetime_funcs", + since = "1.5.0") +// scalastyle:on line.size.limit +case class FromUTCTimestamp(left: Expression, right: Expression) extends UTCTimestamp { + override val func = DateTimeUtils.fromUTCTime + override val funcName: String = "fromUTCTime" + override val prettyName: String = "from_utc_timestamp" +} + +/** + * This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function + * takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given + * timezone, and renders that timestamp as a timestamp in UTC. + * + * However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not + * timezone-agnostic. So in Spark this function just shift the timestamp value from the given + * timezone to UTC timezone. + * + * This function may return confusing result if the input is a string with timezone, e.g. + * '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp + * according to the timezone in the string, and finally display the result by converting the + * timestamp to string according to the session local timezone. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, timezone) - Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14 01:40:00.0'.", + examples = """ + Examples: + > SELECT _FUNC_('2016-08-31', 'Asia/Seoul'); + 2016-08-30 15:00:00 + """, + group = "datetime_funcs", + since = "1.5.0") +// scalastyle:on line.size.limit +case class ToUTCTimestamp(left: Expression, right: Expression) extends UTCTimestamp { + override val func = DateTimeUtils.toUTCTime + override val funcName: String = "toUTCTime" + override val prettyName: String = "to_utc_timestamp" +} + /** * Returns the date that is num_months after start_date. */ @@ -1349,77 +1389,6 @@ case class MonthsBetween( override def prettyName: String = "months_between" } -/** - * This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function - * takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given - * timezone, and renders that timestamp as a timestamp in UTC. - * - * However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not - * timezone-agnostic. So in Spark this function just shift the timestamp value from the given - * timezone to UTC timezone. - * - * This function may return confusing result if the input is a string with timezone, e.g. - * '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp - * according to the timezone in the string, and finally display the result by converting the - * timestamp to string according to the session local timezone. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(timestamp, timezone) - Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time zone, and renders that time as a timestamp in UTC. 
For example, 'GMT+1' would yield '2017-07-14 01:40:00.0'.", - examples = """ - Examples: - > SELECT _FUNC_('2016-08-31', 'Asia/Seoul'); - 2016-08-30 15:00:00 - """, - group = "datetime_funcs", - since = "1.5.0") -// scalastyle:on line.size.limit -case class ToUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) - override def dataType: DataType = TimestampType - override def prettyName: String = "to_utc_timestamp" - - override def nullSafeEval(time: Any, timezone: Any): Any = { - DateTimeUtils.toUTCTime(time.asInstanceOf[Long], - timezone.asInstanceOf[UTF8String].toString) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - if (right.foldable) { - val tz = right.eval().asInstanceOf[UTF8String] - if (tz == null) { - ev.copy(code = code""" - |boolean ${ev.isNull} = true; - |long ${ev.value} = 0; - """.stripMargin) - } else { - val tzClass = classOf[ZoneId].getName - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val escapedTz = StringEscapeUtils.escapeJava(tz.toString) - val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getZoneId("$escapedTz");""") - val utcTerm = "java.time.ZoneOffset.UTC" - val eval = left.genCode(ctx) - ev.copy(code = code""" - |${eval.code} - |boolean ${ev.isNull} = ${eval.isNull}; - |long ${ev.value} = 0; - |if (!${ev.isNull}) { - | ${ev.value} = $dtu.convertTz(${eval.value}, $tzTerm, $utcTerm); - |} - """.stripMargin) - } - } else { - defineCodeGen(ctx, ev, (timestamp, format) => { - s"""$dtu.toUTCTime($timestamp, $format.toString())""" - }) - } - } -} - /** * Parses a column to a date based on the given format. */ @@ -1450,8 +1419,7 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr extends RuntimeReplaceable { def this(left: Expression, format: Expression) { - this(left, Option(format), - Cast(SecondsToTimestamp(UnixTimestamp(left, format)), DateType)) + this(left, Option(format), Cast(GetTimestamp(left, format), DateType)) } def this(left: Expression) = { @@ -2172,4 +2140,3 @@ case class SubtractDates(left: Expression, right: Expression) }) } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d5de95c65e49e..361bcd492965b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -678,6 +678,13 @@ object MapObjects { elementType: DataType, elementNullable: Boolean = true, customCollectionCls: Option[Class[_]] = None): MapObjects = { + // UnresolvedMapObjects does not serialize its 'function' field. + // If an array expression or array Encoder is not correctly resolved before + // serialization, this exception condition may occur. + require(function != null, + "MapObjects applied with a null function. " + + "Likely cause is failure to resolve an array expression or encoder. 
" + + "(See UnresolvedMapObjects)") val loopVar = LambdaVariable("MapObject", elementType, elementNullable) MapObjects(loopVar, function(loopVar), inputData, customCollectionCls) } @@ -734,7 +741,7 @@ case class MapObjects private( case ObjectType(cls) if cls.isArray => _.asInstanceOf[Array[_]].toSeq case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => - _.asInstanceOf[java.util.List[_]].asScala + _.asInstanceOf[java.util.List[_]].asScala.toSeq case ObjectType(cls) if cls == classOf[Object] => (inputCollection) => { if (inputCollection.getClass.isArray) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 8bf1f19844556..d950fef3b26a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -124,7 +124,7 @@ package object expressions { } private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = { - m.mapValues(_.distinct).map(identity) + m.mapValues(_.distinct).toMap } /** Map to use for direct case insensitive attribute lookups. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index f46a1c6836fcf..ff8856708c6d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -186,7 +186,7 @@ object SubExprUtils extends PredicateHelper { e } } - outerExpressions + outerExpressions.toSeq } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 70a673bb42457..c145f26472355 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -168,10 +168,10 @@ private[sql] class JSONOptionsInRead( } protected override def checkedEncoding(enc: String): String = { - val isBlacklisted = JSONOptionsInRead.blacklist.contains(Charset.forName(enc)) - require(multiLine || !isBlacklisted, - s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: - |Blacklist: ${JSONOptionsInRead.blacklist.mkString(", ")}""".stripMargin) + val isDenied = JSONOptionsInRead.denyList.contains(Charset.forName(enc)) + require(multiLine || !isDenied, + s"""The $enc encoding must not be included in the denyList when multiLine is disabled: + |denylist: ${JSONOptionsInRead.denyList.mkString(", ")}""".stripMargin) val isLineSepRequired = multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty @@ -188,7 +188,7 @@ private[sql] object JSONOptionsInRead { // only the first lines will have the BOM which leads to impossibility for reading // the rest lines. Besides of that, the lineSep option must have the BOM in such // encodings which can never present between lines. 
- val blacklist = Seq( + val denyList = Seq( Charset.forName("UTF-16"), Charset.forName("UTF-32") ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index f79dabf758c14..1c33a2c7c3136 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -39,7 +39,18 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { // Remove redundant field extraction. case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => createNamedStruct.valExprs(ordinal) - + case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) => + val name = w.dataType(ordinal).name + val matches = names.zip(valExprs).filter(_._1 == name) + if (matches.nonEmpty) { + // return last matching element as that is the final value for the field being extracted. + // For example, if a user submits a query like this: + // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")` + // we want to return `lit(2)` (and not `lit(1)`). + matches.last._2 + } else { + GetStructField(struct, ordinal, maybeName) + } // Remove redundant array indexing. case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) => // Instead of selecting the field on the entire array, select it from each member diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index f92d8f5b8e534..c450ea891a612 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -162,7 +162,7 @@ object JoinReorderDP extends PredicateHelper with Logging { val topOutputSet = AttributeSet(output) while (foundPlans.size < items.length) { // Build plans for the next level. - foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet, filters) + foundPlans += searchLevel(foundPlans.toSeq, conf, conditions, topOutputSet, filters) } val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 118f41f9cd232..0c8666b72cace 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -149,10 +149,12 @@ object NestedColumnAliasing { case _ => false } + // Note that when we group by extractors with their references, we should remove + // cosmetic variations. val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences) val aliasSub = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]] .filter(!_.references.subsetOf(exclusiveAttrSet)) - .groupBy(_.references.head) + .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute]) .flatMap { case (attr, nestedFields: Seq[ExtractValue]) => // Remove redundant `ExtractValue`s if they share the same parent nest field. // For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`. 
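[Illustrative sketch, not part of the patch] The SimplifyExtractValueOps case added above keeps only the last write to a field when a WithFields chain is immediately followed by a field extraction. The query quoted in that hunk, written out against the Column.withField API that the WithFields expression backs (assuming a SparkSession named spark):

    import org.apache.spark.sql.functions._
    import spark.implicits._

    val df = Seq((1, 2)).toDF("a", "b").select(struct($"a", $"b").as("struct_col"))

    // The second withField("b", ...) wins, so the projection is simplified to lit(2).
    df.select($"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b"))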
@@ -174,9 +176,12 @@ object NestedColumnAliasing { // If all nested fields of `attr` are used, we don't need to introduce new aliases. // By default, ColumnPruning rule uses `attr` already. + // Note that we need to remove cosmetic variations first, so we only count a + // nested field once. if (nestedFieldToAlias.nonEmpty && - nestedFieldToAlias - .map { case (nestedField, _) => totalFieldNum(nestedField.dataType) } + dedupNestedFields.map(_.canonicalized) + .distinct + .map { nestedField => totalFieldNum(nestedField.dataType) } .sum < totalFieldNum(attr.dataType)) { Some(attr.exprId -> nestedFieldToAlias) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 8d5dbc7dc90eb..10f846cf910f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window} @@ -119,6 +119,15 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case _ if expr.dataType == FloatType || expr.dataType == DoubleType => KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) + case If(cond, trueValue, falseValue) => + If(cond, normalize(trueValue), normalize(falseValue)) + + case CaseWhen(branches, elseVale) => + CaseWhen(branches.map(br => (br._1, normalize(br._2))), elseVale.map(normalize)) + + case Coalesce(children) => + Coalesce(children.map(normalize)) + case _ if expr.dataType.isInstanceOf[StructType] => val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i => normalize(GetStructField(expr, i)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e800ee3b93f51..33da482c4eea4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -48,7 +48,7 @@ abstract class Optimizer(catalogManager: CatalogManager) plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty) } - override protected val blacklistedOnceBatches: Set[String] = + override protected val excludedOnceBatches: Set[String] = Set( "PartitionPruning", "Extract Python UDFs", @@ -107,6 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateSerialization, RemoveRedundantAliases, RemoveNoopOperators, + 
CombineWithFields, SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules @@ -160,7 +161,10 @@ abstract class Optimizer(catalogManager: CatalogManager) // LocalRelation and does not trigger many rules. Batch("LocalRelation early", fixedPoint, ConvertToLocalRelation, - PropagateEmptyRelation) :: + PropagateEmptyRelation, + // PropagateEmptyRelation can change the nullability of an attribute from nullable to + // non-nullable when an empty relation child of a Union is removed + UpdateAttributeNullability) :: Batch("Pullup Correlated Expressions", Once, PullupCorrelatedPredicates) :: // Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense @@ -197,7 +201,10 @@ abstract class Optimizer(catalogManager: CatalogManager) ReassignLambdaVariableID) :+ Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, - PropagateEmptyRelation) :+ + PropagateEmptyRelation, + // PropagateEmptyRelation can change the nullability of an attribute from nullable to + // non-nullable when an empty relation child of a Union is removed + UpdateAttributeNullability) :+ // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :+ @@ -207,7 +214,8 @@ abstract class Optimizer(catalogManager: CatalogManager) CollapseProject, RemoveNoopOperators) :+ // This batch must be executed after the `RewriteSubquery` batch, which creates joins. - Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) + Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ + Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression) // remove any batches with no rules. this may happen when subclasses do not add optional rules. batches.filter(_.rules.nonEmpty) @@ -240,7 +248,8 @@ abstract class Optimizer(catalogManager: CatalogManager) PullupCorrelatedPredicates.ruleName :: RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: - NormalizeFloatingNumbers.ruleName :: Nil + NormalizeFloatingNumbers.ruleName :: + ReplaceWithFieldsExpression.ruleName :: Nil /** * Optimize all the subqueries inside expression. @@ -364,38 +373,38 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { /** * Remove the top-level alias from an expression when it is redundant. */ - private def removeRedundantAlias(e: Expression, blacklist: AttributeSet): Expression = e match { + private def removeRedundantAlias(e: Expression, excludeList: AttributeSet): Expression = e match { // Alias with metadata can not be stripped, or the metadata will be lost. // If the alias name is different from attribute name, we can't strip it either, or we // may accidentally change the output schema name of the root plan. case a @ Alias(attr: Attribute, name) if a.metadata == Metadata.empty && name == attr.name && - !blacklist.contains(attr) && - !blacklist.contains(a) => + !excludeList.contains(attr) && + !excludeList.contains(a) => attr case a => a } /** - * Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to - * prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self) - * join or to prevent the removal of top-level subquery attributes. + * Remove redundant alias expression from a LogicalPlan and its subtree. A set of excludes is used + * to prevent the removal of seemingly redundant aliases used to deduplicate the input for a + * (self) join or to prevent the removal of top-level subquery attributes. 
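[Illustrative sketch, not part of the patch] The scenario the scaladoc above refers to, assuming a SparkSession named spark: after analysis deduplicates a self-join, the seemingly redundant aliases on one side must survive this rule, otherwise both sides would expose the same attribute ids.

    import spark.implicits._

    val df = spark.range(3).toDF("id")
    // Both inputs come from the same plan; the exclude set keeps the deduplicating
    // aliases on the right side from being stripped.
    df.as("l").join(df.as("r"), $"l.id" === $"r.id").explain(true)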
*/ - private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = { + private def removeRedundantAliases(plan: LogicalPlan, excluded: AttributeSet): LogicalPlan = { plan match { // We want to keep the same output attributes for subqueries. This means we cannot remove // the aliases that produce these attributes case Subquery(child, correlated) => - Subquery(removeRedundantAliases(child, blacklist ++ child.outputSet), correlated) + Subquery(removeRedundantAliases(child, excluded ++ child.outputSet), correlated) // A join has to be treated differently, because the left and the right side of the join are - // not allowed to use the same attributes. We use a blacklist to prevent us from creating a - // situation in which this happens; the rule will only remove an alias if its child + // not allowed to use the same attributes. We use an exclude list to prevent us from creating + // a situation in which this happens; the rule will only remove an alias if its child // attribute is not on the black list. case Join(left, right, joinType, condition, hint) => - val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet) - val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet) + val newLeft = removeRedundantAliases(left, excluded ++ right.outputSet) + val newRight = removeRedundantAliases(right, excluded ++ newLeft.outputSet) val mapping = AttributeMap( createAttributeMapping(left, newLeft) ++ createAttributeMapping(right, newRight)) @@ -408,7 +417,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // Remove redundant aliases in the subtree(s). val currentNextAttrPairs = mutable.Buffer.empty[(Attribute, Attribute)] val newNode = plan.mapChildren { child => - val newChild = removeRedundantAliases(child, blacklist) + val newChild = removeRedundantAliases(child, excluded) currentNextAttrPairs ++= createAttributeMapping(child, newChild) newChild } @@ -416,14 +425,14 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // Create the attribute mapping. Note that the currentNextAttrPairs can contain duplicate // keys in case of Union (this is caused by the PushProjectionThroughUnion rule); in this // case we use the first mapping (which should be provided by the first child). - val mapping = AttributeMap(currentNextAttrPairs) + val mapping = AttributeMap(currentNextAttrPairs.toSeq) // Create a an expression cleaning function for nodes that can actually produce redundant // aliases, use identity otherwise. 
val clean: Expression => Expression = plan match { - case _: Project => removeRedundantAlias(_, blacklist) - case _: Aggregate => removeRedundantAlias(_, blacklist) - case _: Window => removeRedundantAlias(_, blacklist) + case _: Project => removeRedundantAlias(_, excluded) + case _: Aggregate => removeRedundantAlias(_, excluded) + case _: Window => removeRedundantAlias(_, excluded) case _ => identity[Expression] } @@ -931,7 +940,7 @@ object CombineUnions extends Rule[LogicalPlan] { flattened += child } } - Union(flattened) + Union(flattened.toSeq) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index b19e13870aa65..0299646150ff3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -50,8 +50,26 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit override def conf: SQLConf = SQLConf.get def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p: Union if p.children.forall(isEmptyLocalRelation) => - empty(p) + case p @ Union(children) if children.exists(isEmptyLocalRelation) => + val newChildren = children.filterNot(isEmptyLocalRelation) + if (newChildren.isEmpty) { + empty(p) + } else { + val newPlan = if (newChildren.size > 1) Union(newChildren) else newChildren.head + val outputs = newPlan.output.zip(p.output) + // the original Union may produce different output attributes than the new one so we alias + // them if needed + if (outputs.forall { case (newAttr, oldAttr) => newAttr.exprId == oldAttr.exprId }) { + newPlan + } else { + val outputAliases = outputs.map { case (newAttr, oldAttr) => + val newExplicitMetadata = + if (oldAttr.metadata != newAttr.metadata) Some(oldAttr.metadata) else None + Alias(newAttr, oldAttr.name)(oldAttr.exprId, explicitMetadata = newExplicitMetadata) + } + Project(outputAliases, newPlan) + } + } // Joins on empty LocalRelations generated from streaming sources are not eliminated // as stateful streaming joins need to perform other state management operations other than diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala new file mode 100644 index 0000000000000..05c90864e4bb0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.WithFields +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + + +/** + * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression. + */ +object CombineWithFields extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => + WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) + } +} + +/** + * Replaces [[WithFields]] expression with an evaluable expression. + */ +object ReplaceWithFieldsExpression extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case w: WithFields => w.evalExpr + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index bd400f86ea2c1..773ee7708aea3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -677,7 +677,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { } /** - * Whitelist of all [[UnaryNode]]s for which allow foldable propagation. + * List of all [[UnaryNode]]s which allow foldable propagation. */ private def canPropagateFoldables(u: UnaryNode): Boolean = u match { case _: Project => true @@ -765,7 +765,7 @@ object CombineConcats extends Rule[LogicalPlan] { flattened += child } } - Concat(flattened) + Concat(flattened.toSeq) } private def hasNestedConcats(concat: Concat): Boolean = concat.children.exists { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 6fdd2110ab12a..7b696912aa465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -478,11 +478,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { while (true) { bottomPart match { case havingPart @ Filter(_, aggPart: Aggregate) => - return (topPart, Option(havingPart), aggPart) + return (topPart.toSeq, Option(havingPart), aggPart) case aggPart: Aggregate => // No HAVING clause - return (topPart, None, aggPart) + return (topPart.toSeq, None, aggPart) case p @ Project(_, child) => topPart += p diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d08bcb1420176..29621e11e534c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -141,7 +141,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging s"CTE definition can't have duplicate names: ${duplicates.mkString("'", "', '", "'")}.", ctx) } - With(plan, ctes) + With(plan, ctes.toSeq) } /** @@ -182,7 +182,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging if (selects.length == 1) { selects.head } else { - Union(selects) + Union(selects.toSeq) } } @@ -229,7 +229,7 @@ class AstBuilder(conf: 
SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging if (inserts.length == 1) { inserts.head } else { - Union(inserts) + Union(inserts.toSeq) } } @@ -389,7 +389,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging assignCtx.assignment().asScala.map { assign => Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)), expression(assign.value)) - } + }.toSeq } override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) { @@ -444,7 +444,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging throw new ParseException("The number of inserted values cannot match the fields.", clause.notMatchedAction()) } - InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2))) + InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2)).toSeq) } } else { // It should not be here. @@ -473,8 +473,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging aliasedTarget, aliasedSource, mergeCondition, - matchedActions, - notMatchedActions) + matchedActions.toSeq, + notMatchedActions.toSeq) } /** @@ -490,7 +490,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for // partition columns will be done in analyzer. - checkDuplicateKeys(parts, ctx) + checkDuplicateKeys(parts.toSeq, ctx) parts.toMap } @@ -530,17 +530,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val withOrder = if ( !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { // ORDER BY ... - Sort(order.asScala.map(visitSortItem), global = true, query) + Sort(order.asScala.map(visitSortItem).toSeq, global = true, query) } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { // SORT BY ... - Sort(sort.asScala.map(visitSortItem), global = false, query) + Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query) } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { // DISTRIBUTE BY ... withRepartitionByExpression(ctx, expressionList(distributeBy), query) } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { // SORT BY ... DISTRIBUTE BY ... Sort( - sort.asScala.map(visitSortItem), + sort.asScala.map(visitSortItem).toSeq, global = false, withRepartitionByExpression(ctx, expressionList(distributeBy), query)) } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { @@ -841,7 +841,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // Note that mapValues creates a view instead of materialized map. We force materialization by // mapping over identity. - WithWindowDefinition(windowMapView.map(identity), query) + WithWindowDefinition(windowMapView.map(identity).toMap, query) } /** @@ -856,8 +856,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging if (ctx.GROUPING != null) { // GROUP BY .... GROUPING SETS (...) 
val selectedGroupByExprs = - ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e))) - GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions) + ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq) + GroupingSets(selectedGroupByExprs.toSeq, groupByExpressions, query, selectExpressions) } else { // GROUP BY .... (WITH CUBE | WITH ROLLUP)? val mappedGroupByExpressions = if (ctx.CUBE != null) { @@ -878,8 +878,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging ctx: HintContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { var plan = query - ctx.hintStatements.asScala.reverse.foreach { case stmt => - plan = UnresolvedHint(stmt.hintName.getText, stmt.parameters.asScala.map(expression), plan) + ctx.hintStatements.asScala.reverse.foreach { stmt => + plan = UnresolvedHint(stmt.hintName.getText, + stmt.parameters.asScala.map(expression).toSeq, plan) } plan } @@ -898,10 +899,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } else { CreateStruct( ctx.pivotColumn.identifiers.asScala.map( - identifier => UnresolvedAttribute.quoted(identifier.getText))) + identifier => UnresolvedAttribute.quoted(identifier.getText)).toSeq) } val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) - Pivot(None, pivotColumn, pivotValues, aggregates, query) + Pivot(None, pivotColumn, pivotValues.toSeq, aggregates, query) } /** @@ -930,7 +931,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // scalastyle:off caselocale Some(ctx.tblName.getText.toLowerCase), // scalastyle:on caselocale - ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply).toSeq, query) } @@ -1081,7 +1082,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } val tvf = UnresolvedTableValuedFunction( - func.funcName.getText, func.expression.asScala.map(expression), aliases) + func.funcName.getText, func.expression.asScala.map(expression).toSeq, aliases) tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan) } @@ -1106,7 +1107,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } - val table = UnresolvedInlineTable(aliases, rows) + val table = UnresolvedInlineTable(aliases, rows.toSeq) table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) } @@ -1180,7 +1181,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create a Sequence of Strings for an identifier list. */ override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { - ctx.ident.asScala.map(_.getText) + ctx.ident.asScala.map(_.getText).toSeq } /* ******************************************************************************************** @@ -1205,10 +1206,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging /** * Create a multi-part identifier. 
*/ - override def visitMultipartIdentifier( - ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) { - ctx.parts.asScala.map(_.getText) - } + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = + withOrigin(ctx) { + ctx.parts.asScala.map(_.getText).toSeq + } /* ******************************************************************************************** * Expression parsing @@ -1223,7 +1224,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create sequence of expressions from the given sequence of contexts. */ private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { - trees.asScala.map(expression) + trees.asScala.map(expression).toSeq } /** @@ -1231,7 +1232,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Both un-targeted (global) and targeted aliases are supported. */ override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { - UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText).toSeq)) } /** @@ -1387,7 +1388,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging if (expressions.isEmpty) { throw new ParseException("Expected something between '(' and ')'.", ctx) } else { - expressions.asScala.map(expression).map(p => invertIfNotDefined(new Like(e, p))) + expressions.asScala.map(expression).map(p => invertIfNotDefined(new Like(e, p))).toSeq } } @@ -1401,7 +1402,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.IN if ctx.query != null => invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) case SqlBaseParser.IN => - invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression).toSeq)) case SqlBaseParser.LIKE => Option(ctx.quantifier).map(_.getType) match { case Some(SqlBaseParser.ANY) | Some(SqlBaseParser.SOME) => @@ -1526,7 +1527,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create a [[CreateStruct]] expression. */ override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { - CreateStruct.create(ctx.argument.asScala.map(expression)) + CreateStruct.create(ctx.argument.asScala.map(expression).toSeq) } /** @@ -1617,7 +1618,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // Transform COUNT(*) into COUNT(1). Seq(Literal(1)) case expressions => - expressions + expressions.toSeq } val filter = Option(ctx.where).map(expression(_)) val function = UnresolvedFunction( @@ -1639,14 +1640,14 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS. */ protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = { - visitFunctionName(ctx, ctx.parts.asScala.map(_.getText)) + visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq) } /** * Create a function database (optional) and name pair. 
*/ protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = { - visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText)) + visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq) } /** @@ -1682,7 +1683,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val function = expression(ctx.expression).transformUp { case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts) } - LambdaFunction(function, arguments) + LambdaFunction(function, arguments.toSeq) } /** @@ -1714,8 +1715,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } WindowSpecDefinition( - partition, - order, + partition.toSeq, + order.toSeq, frameSpecOption.getOrElse(UnspecifiedFrame)) } @@ -1747,7 +1748,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create a [[CreateStruct]] expression. */ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { - CreateStruct(ctx.namedExpression().asScala.map(expression)) + CreateStruct(ctx.namedExpression().asScala.map(expression).toSeq) } /** @@ -1773,7 +1774,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val branches = ctx.whenClause.asScala.map { wCtx => (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) } - CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression)) } /** @@ -1792,7 +1793,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val branches = ctx.whenClause.asScala.map { wCtx => (expression(wCtx.condition), expression(wCtx.result)) } - CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression)) } /** @@ -2030,6 +2031,15 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong) } + /** + * Create a Float Literal expression. + */ + override def visitFloatLiteral(ctx: FloatLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Float.MinValue, Float.MaxValue, FloatType.simpleString)(_.toFloat) + } + /** * Create a Double Literal expression. */ @@ -2203,6 +2213,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging DecimalType(precision.getText.toInt, 0) case ("decimal" | "dec" | "numeric", precision :: scale :: Nil) => DecimalType(precision.getText.toInt, scale.getText.toInt) + case ("void", Nil) => NullType case ("interval", Nil) => CalendarIntervalType case (dt, params) => val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt @@ -2235,7 +2246,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create a [[StructType]] from a number of column definitions. 
*/ override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { - ctx.colType().asScala.map(visitColType) + ctx.colType().asScala.map(visitColType).toSeq } /** @@ -2276,7 +2287,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitComplexColTypeList( ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) { - ctx.complexColType().asScala.map(visitComplexColType) + ctx.complexColType().asScala.map(visitComplexColType).toSeq } /** @@ -2352,7 +2363,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging key -> value } // Check for duplicate property names. - checkDuplicateKeys(properties, ctx) + checkDuplicateKeys(properties.toSeq, ctx) properties.toMap } @@ -2433,7 +2444,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging if (temporary && ifNotExists) { operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx) } - val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText) + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null) } @@ -2442,7 +2453,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitReplaceTableHeader( ctx: ReplaceTableHeaderContext): TableHeader = withOrigin(ctx) { - val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText) + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq (multipartIdentifier, false, false, false) } @@ -2450,7 +2461,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Parse a qualified name to a multipart name. 
*/ override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) { - ctx.identifier.asScala.map(_.getText) + ctx.identifier.asScala.map(_.getText).toSeq } /** @@ -2488,7 +2499,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName))) case applyCtx: ApplyTransformContext => - val arguments = applyCtx.argument.asScala.map(visitTransformArgument) + val arguments = applyCtx.argument.asScala.map(visitTransformArgument).toSeq applyCtx.identifier.getText match { case "bucket" => @@ -2505,7 +2516,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg)) - BucketTransform(LiteralValue(numBuckets, IntegerType), fields) + BucketTransform(LiteralValue(numBuckets, IntegerType), fields.toSeq) case "years" => YearsTransform(getSingleFieldReference(applyCtx, arguments)) @@ -2522,7 +2533,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case name => ApplyTransform(name, arguments) } - } + }.toSeq } /** @@ -2946,7 +2957,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) { AlterTableAddColumnsStatement( visitMultipartIdentifier(ctx.multipartIdentifier), - ctx.columns.qualifiedColTypeWithPosition.asScala.map(typedVisit[QualifiedColType]) + ctx.columns.qualifiedColTypeWithPosition.asScala.map(typedVisit[QualifiedColType]).toSeq ) } @@ -2962,7 +2973,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging ctx: RenameTableColumnContext): LogicalPlan = withOrigin(ctx) { AlterTableRenameColumnStatement( visitMultipartIdentifier(ctx.table), - ctx.from.parts.asScala.map(_.getText), + ctx.from.parts.asScala.map(_.getText).toSeq, ctx.to.getText) } @@ -3074,7 +3085,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging "Column position is not supported in Hive-style REPLACE COLUMNS") } typedVisit[QualifiedColType](colType) - } + }.toSeq ) } @@ -3092,7 +3103,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val columnsToDrop = ctx.columns.multipartIdentifier.asScala.map(typedVisit[Seq[String]]) AlterTableDropColumnsStatement( visitMultipartIdentifier(ctx.multipartIdentifier), - columnsToDrop) + columnsToDrop.toSeq) } /** @@ -3165,7 +3176,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } else { DescribeColumnStatement( visitMultipartIdentifier(ctx.multipartIdentifier()), - ctx.describeColName.nameParts.asScala.map(_.getText), + ctx.describeColName.nameParts.asScala.map(_.getText).toSeq, isExtended) } } else { @@ -3401,7 +3412,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } AlterTableAddPartitionStatement( visitMultipartIdentifier(ctx.multipartIdentifier), - specsAndLocs, + specsAndLocs.toSeq, ctx.EXISTS != null) } @@ -3441,7 +3452,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } AlterTableDropPartitionStatement( visitMultipartIdentifier(ctx.multipartIdentifier), - ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), + ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec).toSeq, ifExists = ctx.EXISTS != null, purge = ctx.PURGE != null, retainData = false) @@ -3636,7 +3647,7 @@ class AstBuilder(conf: 
SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging CreateFunctionStatement( functionIdentifier, string(ctx.className), - resources, + resources.toSeq, ctx.TEMPORARY != null, ctx.EXISTS != null, ctx.REPLACE != null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 039fd9382000a..f1a363cca752e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1028,7 +1028,7 @@ case class Deduplicate( /** * A trait to represent the commands that support subqueries. - * This is used to whitelist such commands in the subquery-related checks. + * This is used to allow such commands in the subquery-related checks. */ trait SupportsSubquery extends LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index e1dbef9ebeede..967ccedeeeacb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -344,7 +344,7 @@ object EstimationUtils { } } } - overlappedRanges + overlappedRanges.toSeq } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 19a0d1279cc32..777a4c8291223 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -323,7 +323,7 @@ case class JoinEstimation(join: Join) extends Logging { outputAttrStats += a -> newColStat } } - outputAttrStats + outputAttrStats.toSeq } private def extractJoinKeysWithColStats( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index bff04d317d4d2..2109e8f355c5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -80,8 +80,8 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { /** Defines a sequence of rule batches, to be overridden by the implementation. */ protected def batches: Seq[Batch] - /** Once batches that are blacklisted in the idempotence checker */ - protected val blacklistedOnceBatches: Set[String] = Set.empty + /** Once batches that are excluded in the idempotence checker */ + protected val excludedOnceBatches: Set[String] = Set.empty /** * Defines a check function that checks for structural integrity of the plan after the execution @@ -189,7 +189,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { } // Check idempotence for Once batches. 
if (batch.strategy == Once && - Utils.isTesting && !blacklistedOnceBatches.contains(batch.name)) { + Utils.isTesting && !excludedOnceBatches.contains(batch.name)) { checkBatchIdempotence(batch, curPlan) } continue = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index c4a106702a515..6cd062da2b94a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -185,7 +185,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def map[A](f: BaseType => A): Seq[A] = { val ret = new collection.mutable.ArrayBuffer[A]() foreach(ret += f(_)) - ret + ret.toSeq } /** @@ -195,7 +195,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A] = { val ret = new collection.mutable.ArrayBuffer[A]() foreach(ret ++= f(_)) - ret + ret.toSeq } /** @@ -206,7 +206,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { val ret = new collection.mutable.ArrayBuffer[B]() val lifted = pf.lift foreach(node => lifted(node).foreach(ret.+=)) - ret + ret.toSeq } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 1f88a700847de..711ef265c6cf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -26,7 +26,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seq: Seq[Any]) = this(seq.toArray) - def this(list: java.util.List[Any]) = this(list.asScala) + def this(list: java.util.List[Any]) = this(list.asScala.toSeq) // TODO: This is boxing. We should specialize. 
def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index 3a0490d07733d..2797a40614504 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -223,7 +223,7 @@ class QuantileSummaries( otherIdx += 1 } - val comp = compressImmut(mergedSampled, 2 * mergedRelativeError * mergedCount) + val comp = compressImmut(mergedSampled.toIndexedSeq, 2 * mergedRelativeError * mergedCount) new QuantileSummaries(other.compressThreshold, mergedRelativeError, comp, mergedCount, true) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index e1f329352592f..1a3a7207c6ca9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.AlterTable import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.types.{ArrayType, DataType, HIVE_TYPE_STRING, HiveStringType, MapType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, HIVE_TYPE_STRING, HiveStringType, MapType, NullType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils @@ -346,4 +346,23 @@ private[sql] object CatalogV2Util { } } } + + def failNullType(dt: DataType): Unit = { + def containsNullType(dt: DataType): Boolean = dt match { + case ArrayType(et, _) => containsNullType(et) + case MapType(kt, vt, _) => containsNullType(kt) || containsNullType(vt) + case StructType(fields) => fields.exists(f => containsNullType(f.dataType)) + case _ => dt.isInstanceOf[NullType] + } + if (containsNullType(dt)) { + throw new AnalysisException( + s"Cannot create tables with ${NullType.simpleString} type.") + } + } + + def assertNoNullTypeInSchema(schema: StructType): Unit = { + schema.foreach { f => + failNullType(f.dataType) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3149d14c1ddcc..9be0497e46603 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2106,6 +2106,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DISABLE_HINTS = + buildConf("spark.sql.optimizer.disableHints") + .internal() + .doc("When true, the optimizer will disable user-specified hints that are additional " + + "directives for better planning of a query.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + val NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST = buildConf("spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources") .internal() @@ -2634,7 +2643,7 @@ object SQLConf { "when false, forbid the cast, more details in SPARK-31710") .version("3.1.0") .booleanConf - .createWithDefault(false) + 
.createWithDefault(true) val COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_ENABLED = buildConf("spark.sql.bucketing.coalesceBucketsInSortMergeJoin.enabled") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsMetadata.scala similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsMetadata.scala index 42631c90ebc55..b2cb19b009141 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsMetadata.scala @@ -14,19 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -package org.apache.spark.sql.execution.streaming.continuous.shuffle - -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +package org.apache.spark.sql.internal.connector /** - * Trait for reading from a continuous processing shuffle. + * A mix-in interface for {@link FileScan}. This can be used to report metadata + * for a file based scan operator. This is currently used for supporting formatted + * explain. */ -trait ContinuousShuffleReader { - /** - * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting - * for new rows to arrive, and end the iterator once they've received epoch markers from all - * shuffle writers. - */ - def read(): Iterator[UnsafeRow] +trait SupportsMetadata { + def getMetaData(): Map[String, String] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index bd2c1d5c26299..b14fb04cc4539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -423,7 +423,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum override def simpleString: String = { - val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}") + val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}").toSeq truncatedString( fieldTypes, "struct<", ",", ">", @@ -542,7 +542,7 @@ object StructType extends AbstractDataType { def apply(fields: java.util.List[StructField]): StructType = { import scala.collection.JavaConverters._ - StructType(fields.asScala) + StructType(fields.asScala.toSeq) } private[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = @@ -606,7 +606,7 @@ object StructType extends AbstractDataType { newFields += f } - StructType(newFields) + StructType(newFields.toSeq) case (DecimalType.Fixed(leftPrecision, leftScale), DecimalType.Fixed(rightPrecision, rightScale)) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 003ce850c926e..c3bc67d76138a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -121,7 +121,7 @@ private[sql] object ArrowUtils { val dt 
= fromArrowField(child) StructField(child.getName, dt, child.isNullable) } - StructType(fields) + StructType(fields.toSeq) case arrowType => fromArrowType(arrowType) } } @@ -137,7 +137,7 @@ private[sql] object ArrowUtils { StructType(schema.getFields.asScala.map { field => val dt = fromArrowField(field) StructField(field.getName, dt, field.isNullable) - }) + }.toSeq) } /** Return Map with conf settings to be used in ArrowPythonRunner */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 6a5bdc4f6fc3d..9fb8b0f351d51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -136,7 +136,7 @@ object RandomDataGenerator { } i += 1 } - StructType(fields) + StructType(fields.toSeq) } /** @@ -372,6 +372,6 @@ object RandomDataGenerator { fields += gen() } } - Row.fromSeq(fields) + Row.fromSeq(fields.toSeq) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala new file mode 100644 index 0000000000000..3d41d02b23df5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import java.io.File +import java.nio.file.Files + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.catalyst.util.fileToString + +trait SQLKeywordUtils extends SQLHelper { + + val sqlSyntaxDefs = { + val sqlBasePath = { + java.nio.file.Paths.get(sparkHome, "sql", "catalyst", "src", "main", "antlr4", "org", + "apache", "spark", "sql", "catalyst", "parser", "SqlBase.g4").toFile + } + fileToString(sqlBasePath).split("\n") + } + + // each element is an array of 4 string: the keyword name, reserve or not in Spark ANSI mode, + // Spark non-ANSI mode, and the SQL standard. 
+ val keywordsInDoc: Array[Array[String]] = { + val docPath = { + java.nio.file.Paths.get(sparkHome, "docs", "sql-ref-ansi-compliance.md").toFile + } + fileToString(docPath).split("\n") + .dropWhile(!_.startsWith("|Keyword|")).drop(2).takeWhile(_.startsWith("|")) + .map(_.stripPrefix("|").split("\\|").map(_.trim)) + } + + private def parseAntlrGrammars[T](startTag: String, endTag: String) + (f: PartialFunction[String, Seq[T]]): Set[T] = { + val keywords = new mutable.ArrayBuffer[T] + val default = (_: String) => Nil + var startTagFound = false + var parseFinished = false + val lineIter = sqlSyntaxDefs.toIterator + while (!parseFinished && lineIter.hasNext) { + val line = lineIter.next() + if (line.trim.startsWith(startTag)) { + startTagFound = true + } else if (line.trim.startsWith(endTag)) { + parseFinished = true + } else if (startTagFound) { + f.applyOrElse(line, default).foreach { symbol => + keywords += symbol + } + } + } + assert(keywords.nonEmpty && startTagFound && parseFinished, "cannot extract keywords from " + + s"the `SqlBase.g4` file, so please check if the start/end tags (`$startTag` and `$endTag`) " + + "are placed correctly in the file.") + keywords.toSet + } + + // If a symbol does not have the same string with its literal (e.g., `SETMINUS: 'MINUS';`), + // we need to map a symbol to actual literal strings. + val symbolsToExpandIntoDifferentLiterals = { + val kwDef = """([A-Z_]+):(.+);""".r + val keywords = parseAntlrGrammars( + "//--SPARK-KEYWORD-LIST-START", "//--SPARK-KEYWORD-LIST-END") { + case kwDef(symbol, literalDef) => + val splitDefs = literalDef.split("""\|""") + val hasMultipleLiterals = splitDefs.length > 1 + // The case where a symbol has multiple literal definitions, + // e.g., `DATABASES: 'DATABASES' | 'SCHEMAS';`. + if (hasMultipleLiterals) { + // Filters out inappropriate entries, e.g., `!` in `NOT: 'NOT' | '!';` + val litDef = """([A-Z_]+)""".r + val literals = splitDefs.map(_.replaceAll("'", "").trim).toSeq.flatMap { + case litDef(lit) => Some(lit) + case _ => None + } + (symbol, literals) :: Nil + } else { + val literal = literalDef.replaceAll("'", "").trim + // The case where a symbol string and its literal string are different, + // e.g., `SETMINUS: 'MINUS';`. 
+ if (symbol != literal) { + (symbol, literal :: Nil) :: Nil + } else { + Nil + } + } + } + keywords.toMap + } + + // All the SQL keywords defined in `SqlBase.g4` + val allCandidateKeywords: Set[String] = { + val kwDef = """([A-Z_]+):.+;""".r + parseAntlrGrammars( + "//--SPARK-KEYWORD-LIST-START", "//--SPARK-KEYWORD-LIST-END") { + // Parses a pattern, e.g., `AFTER: 'AFTER';` + case kwDef(symbol) => + if (symbolsToExpandIntoDifferentLiterals.contains(symbol)) { + symbolsToExpandIntoDifferentLiterals(symbol) + } else { + symbol :: Nil + } + } + } + + val nonReservedKeywordsInAnsiMode: Set[String] = { + val kwDef = """\s*[\|:]\s*([A-Z_]+)\s*""".r + parseAntlrGrammars("//--ANSI-NON-RESERVED-START", "//--ANSI-NON-RESERVED-END") { + // Parses a pattern, e.g., ` | AFTER` + case kwDef(symbol) => + if (symbolsToExpandIntoDifferentLiterals.contains(symbol)) { + symbolsToExpandIntoDifferentLiterals(symbol) + } else { + symbol :: Nil + } + } + } + + val reservedKeywordsInAnsiMode = allCandidateKeywords -- nonReservedKeywordsInAnsiMode +} + +class SQLKeywordSuite extends SparkFunSuite with SQLKeywordUtils { + test("all keywords are documented") { + val documentedKeywords = keywordsInDoc.map(_.head).toSet + if (allCandidateKeywords != documentedKeywords) { + val undocumented = (allCandidateKeywords -- documentedKeywords).toSeq.sorted + fail("Some keywords are not documented: " + undocumented.mkString(", ")) + } + } + + test("Spark keywords are documented correctly") { + val reservedKeywordsInDoc = keywordsInDoc.filter(_.apply(1) == "reserved").map(_.head).toSet + if (reservedKeywordsInAnsiMode != reservedKeywordsInDoc) { + val misImplemented = (reservedKeywordsInDoc -- reservedKeywordsInAnsiMode).toSeq.sorted + fail("Some keywords are documented as reserved but actually not: " + + misImplemented.mkString(", ")) + } + } + + test("SQL 2016 keywords are documented correctly") { + withTempDir { dir => + val tmpFile = new File(dir, "tmp") + val is = Thread.currentThread().getContextClassLoader + .getResourceAsStream("ansi-sql-2016-reserved-keywords.txt") + Files.copy(is, tmpFile.toPath) + val reservedKeywordsInSql2016 = Files.readAllLines(tmpFile.toPath) + .asScala.filterNot(_.startsWith("--")).map(_.trim).toSet + val documented = keywordsInDoc.filter(_.last == "reserved").map(_.head).toSet + assert((documented -- reservedKeywordsInSql2016).isEmpty) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c15ec49e14282..c0be49af2107d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -884,4 +884,15 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq("Intersect can only be performed on tables with the compatible column types. 
" + "timestamp <> double at the second column of the second table")) } + + test("SPARK-31975: Throw user facing error when use WindowFunction directly") { + assertAnalysisError(testRelation2.select(RowNumber()), + Seq("Window function row_number() requires an OVER clause.")) + + assertAnalysisError(testRelation2.select(Sum(RowNumber())), + Seq("Window function row_number() requires an OVER clause.")) + + assertAnalysisError(testRelation2.select(RowNumber() + 1), + Seq("Window function row_number() requires an OVER clause.")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 76ec450a4d7c6..4ab288a34cb08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -49,9 +49,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } protected def checkNullCast(from: DataType, to: DataType): Unit = { - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null) - } + checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null) } test("null cast") { @@ -240,9 +238,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkCast(1.5, 1.5f) checkCast(1.5, "1.5") - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) - } + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) } test("cast from string") { @@ -309,19 +305,17 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType), 5.toLong) - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation( - cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), - DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - 5.toShort) - checkEvaluation( - cast(cast(cast(cast(cast(cast("5", TimestampType, UTC_OPT), ByteType), - DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - null) - checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), - ByteType), TimestampType), LongType), StringType), ShortType), - 5.toShort) - } + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), + 5.toShort) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", TimestampType, UTC_OPT), ByteType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), + null) + checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), + ByteType), TimestampType), LongType), StringType), ShortType), + 5.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) @@ -383,31 +377,29 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, FloatType), 15.003f) checkEvaluation(cast(ts, DoubleType), 15.003) - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation(cast(cast(tss, ShortType), TimestampType), - fromJavaTimestamp(ts) * MILLIS_PER_SECOND) - checkEvaluation(cast(cast(tss, IntegerType), 
TimestampType), - fromJavaTimestamp(ts) * MILLIS_PER_SECOND) - checkEvaluation(cast(cast(tss, LongType), TimestampType), - fromJavaTimestamp(ts) * MILLIS_PER_SECOND) - checkEvaluation( - cast(cast(millis.toFloat / MILLIS_PER_SECOND, TimestampType), FloatType), - millis.toFloat / MILLIS_PER_SECOND) - checkEvaluation( - cast(cast(millis.toDouble / MILLIS_PER_SECOND, TimestampType), DoubleType), - millis.toDouble / MILLIS_PER_SECOND) - checkEvaluation( - cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), - Decimal(1)) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), + fromJavaTimestamp(ts) * MILLIS_PER_SECOND) + checkEvaluation(cast(cast(tss, IntegerType), TimestampType), + fromJavaTimestamp(ts) * MILLIS_PER_SECOND) + checkEvaluation(cast(cast(tss, LongType), TimestampType), + fromJavaTimestamp(ts) * MILLIS_PER_SECOND) + checkEvaluation( + cast(cast(millis.toFloat / MILLIS_PER_SECOND, TimestampType), FloatType), + millis.toFloat / MILLIS_PER_SECOND) + checkEvaluation( + cast(cast(millis.toDouble / MILLIS_PER_SECOND, TimestampType), DoubleType), + millis.toDouble / MILLIS_PER_SECOND) + checkEvaluation( + cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), + Decimal(1)) - // A test for higher precision than millis - checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001) + // A test for higher precision than millis + checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001) - checkEvaluation(cast(Double.NaN, TimestampType), null) - checkEvaluation(cast(1.0 / 0.0, TimestampType), null) - checkEvaluation(cast(Float.NaN, TimestampType), null) - checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) - } + checkEvaluation(cast(Double.NaN, TimestampType), null) + checkEvaluation(cast(1.0 / 0.0, TimestampType), null) + checkEvaluation(cast(Float.NaN, TimestampType), null) + checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) } test("cast from array") { @@ -1036,10 +1028,8 @@ class CastSuite extends CastSuiteBase { test("cast from int 2") { checkEvaluation(cast(1, LongType), 1.toLong) - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> "true") { - checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) - checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) - } + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) @@ -1323,7 +1313,7 @@ class CastSuite extends CastSuiteBase { } } - test("SPARK-31710:fail casting from numeric to timestamp by default") { + test("SPARK-31710: fail casting from numeric to timestamp if it is forbidden") { Seq(true, false).foreach { enable => withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> enable.toString) { assert(cast(2.toByte, TimestampType).resolved == enable) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3a0c02b29d92c..856c1fad9b204 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -933,6 +933,13 @@ class CollectionExpressionsSuite extends 
SparkFunSuite with ExpressionEvalHelper Literal(negateExact(stringToInterval("interval 1 month")))), EmptyRow, s"sequence boundaries: 0 to 2678400000000 by -${28 * MICROS_PER_DAY}") + + // SPARK-32133: Sequence step must be a day interval if start and end values are dates + checkExceptionInExpression[IllegalArgumentException](Sequence( + Cast(Literal("2011-03-01"), DateType), + Cast(Literal("2011-04-01"), DateType), + Option(Literal(stringToInterval("interval 1 hour")))), null, + "sequence step must be a day interval if start and end values are dates") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 5be37318ae6eb..bfa415afeab93 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -127,7 +127,7 @@ class ColumnPruningSuite extends PlanTest { val optimized = Optimize.execute(query) - val aliases = NestedColumnAliasingSuite.collectGeneratedAliases(optimized) + val aliases = NestedColumnAliasingSuite.collectGeneratedAliases(optimized).toSeq val selectedFields = UnresolvedAttribute("a") +: aliasedExprs(aliases) val finalSelectedExprs = Seq(UnresolvedAttribute("a"), $"${aliases(0)}".as("c.d")) ++ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala new file mode 100644 index 0000000000000..a3e0bbc57e639 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +class CombineWithFieldsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil + } + + private val testRelation = LocalRelation('a.struct('a1.int)) + + test("combines two WithFields") { + val originalQuery = testRelation + .select(Alias( + WithFields( + WithFields( + 'a, + Seq("b1"), + Seq(Literal(4))), + Seq("c1"), + Seq(Literal(5))), "out")()) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")()) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("combines three WithFields") { + val originalQuery = testRelation + .select(Alias( + WithFields( + WithFields( + WithFields( + 'a, + Seq("b1"), + Seq(Literal(4))), + Seq("c1"), + Seq(Literal(5))), + Seq("d1"), + Seq(Literal(6))), "out")()) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")()) + .analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index d7eb048ba8705..e2b599a7c090c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -284,6 +284,15 @@ class EliminateSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("SPARK-32318: should not remove orderBy in distribute statement") { + val projectPlan = testRelation.select('a, 'b) + val orderByPlan = projectPlan.orderBy('b.desc) + val distributedPlan = orderByPlan.distribute('a)(1) + val optimized = Optimize.execute(distributedPlan.analyze) + val correctAnswer = distributedPlan.analyze + comparePlans(optimized, correctAnswer) + } + test("should not remove orderBy in left join clause if there is an outer limit") { val projectPlan = testRelation.select('a, 'b) val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index bb8f5f90f8508..bb7e9d04c12d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -33,7 +33,7 @@ class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - override protected val blacklistedOnceBatches: Set[String] = + override protected val excludedOnceBatches: Set[String] = Set("Push CNF predicate through join") val batches = diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala index f5af416602c9d..bb9919f94eef2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, IsNull, KnownFloatingPointNormalized} +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, IsNull, KnownFloatingPointNormalized} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -91,5 +91,38 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest { comparePlans(doubleOptimized, correctAnswer) } + + test("SPARK-32258: normalize the children of If") { + val cond = If(a > 0.1D, namedStruct("a", a), namedStruct("a", a + 0.2D)) === namedStruct("a", b) + val query = testRelation1.join(testRelation2, condition = Some(cond)) + val optimized = Optimize.execute(query) + val doubleOptimized = Optimize.execute(optimized) + + val joinCond = If(a > 0.1D, + namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), + namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D)))) === + namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(b))) + val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond)) + + comparePlans(doubleOptimized, correctAnswer) + } + + test("SPARK-32258: normalize the children of CaseWhen") { + val cond = CaseWhen( + Seq((a > 0.1D, namedStruct("a", a)), (a > 0.2D, namedStruct("a", a + 0.2D))), + Some(namedStruct("a", a + 0.3D))) === namedStruct("a", b) + val query = testRelation1.join(testRelation2, condition = Some(cond)) + val optimized = Optimize.execute(query) + val doubleOptimized = Optimize.execute(optimized) + + val joinCond = CaseWhen( + Seq((a > 0.1D, namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a)))), + (a > 0.2D, namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D))))), + Some(namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.3D))))) === + namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(b))) + val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond)) + + comparePlans(doubleOptimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 9c7d4c7d8d233..dc323d4e5c77c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import 
org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructType} class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -55,6 +55,9 @@ class PropagateEmptyRelationSuite extends PlanTest { val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) val testRelation2 = LocalRelation.fromExternalRows(Seq('b.int), data = Seq(Row(1))) + val metadata = new MetadataBuilder().putLong("test", 1).build() + val testRelation3 = + LocalRelation.fromExternalRows(Seq('c.int.notNull.withMetadata(metadata)), data = Seq(Row(1))) test("propagate empty relation through Union") { val query = testRelation1 @@ -67,6 +70,39 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("SPARK-32241: remove empty relation children from Union") { + val query = testRelation1.union(testRelation2.where(false)) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation1 + comparePlans(optimized, correctAnswer) + + val query2 = testRelation1.where(false).union(testRelation2) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer2 = testRelation2.select('b.as('a)).analyze + comparePlans(optimized2, correctAnswer2) + + val query3 = testRelation1.union(testRelation2.where(false)).union(testRelation3) + val optimized3 = Optimize.execute(query3.analyze) + val correctAnswer3 = testRelation1.union(testRelation3) + comparePlans(optimized3, correctAnswer3) + + val query4 = testRelation1.where(false).union(testRelation2).union(testRelation3) + val optimized4 = Optimize.execute(query4.analyze) + val correctAnswer4 = testRelation2.union(testRelation3).select('b.as('a)).analyze + comparePlans(optimized4, correctAnswer4) + + // Nullability can change from nullable to non-nullable + val query5 = testRelation1.where(false).union(testRelation3) + val optimized5 = Optimize.execute(query5.analyze) + assert(query5.output.head.nullable, "Original output should be nullable") + assert(!optimized5.output.head.nullable, "New output should be non-nullable") + + // Keep metadata + val query6 = testRelation3.where(false).union(testRelation1) + val optimized6 = Optimize.execute(query6.analyze) + assert(optimized6.output.head.metadata == metadata, "New output should keep metadata") + } + test("propagate empty relation through Join") { // Testcases are tuples of (left predicate, right predicate, joinType, correct answer) // Note that `None` is used to compare with OptimizeWithoutPropagateEmptyRelation. 
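The SPARK-32241 tests above exercise partial pruning of statically empty Union children: when a branch can be proven empty at optimization time it is dropped, and if the first branch is the one removed, the surviving child's output is re-aliased to the original Union attribute names, which is why the tests also check that nullability can tighten and column metadata is preserved. A minimal spark-shell sketch of the observable effect, assuming a running SparkSession named `spark` (illustrative only, not part of this patch):

import org.apache.spark.sql.functions.lit

val left  = spark.range(3).toDF("a")
val right = spark.range(3).toDF("b").where(lit(false))   // statically empty branch

// With this optimizer change, the empty branch should be pruned from the Union,
// so the optimized plan shown by explain degenerates to the non-empty side.
left.union(right).explain(true)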
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index 2d86d5a97e769..e7775705edc5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor class PullupCorrelatedPredicatesSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - override protected val blacklistedOnceBatches = Set("PullupCorrelatedPredicates") + override protected val excludedOnceBatches = Set("PullupCorrelatedPredicates") val batches = Batch("PullupCorrelatedPredicates", Once, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index d55746002783a..c71e7dbe7d6f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -452,4 +452,61 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) } + + private val structAttr = 'struct1.struct('a.int) + private val testStructRelation = LocalRelation(structAttr) + + test("simplify GetStructField on WithFields that is not changing the attribute being extracted") { + val query = testStructRelation.select( + GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAtt") + val expected = testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAtt") + checkRule(query, expected) + } + + test("simplify GetStructField on WithFields that is changing the attribute being extracted") { + val query = testStructRelation.select( + GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "outerAtt") + val expected = testStructRelation.select(Literal(1) as "outerAtt") + checkRule(query, expected) + } + + test( + "simplify GetStructField on WithFields that is changing the attribute being extracted twice") { + val query = testStructRelation + .select(GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1, + Some("b")) as "outerAtt") + val expected = testStructRelation.select(Literal(2) as "outerAtt") + checkRule(query, expected) + } + + test("collapse multiple GetStructField on the same WithFields") { + val query = testStructRelation + .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2") + .select( + GetStructField('struct2, 0, Some("a")) as "struct1A", + GetStructField('struct2, 1, Some("b")) as "struct1B") + val expected = testStructRelation.select( + GetStructField('struct1, 0, Some("a")) as "struct1A", + Literal(2) as "struct1B") + checkRule(query, expected) + } + + test("collapse multiple GetStructField on different WithFields") { + val query = testStructRelation + .select( + WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2", + WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3") + .select( + GetStructField('struct2, 0, Some("a")) as "struct2A", + GetStructField('struct2, 
1, Some("b")) as "struct2B", + GetStructField('struct3, 0, Some("a")) as "struct3A", + GetStructField('struct3, 1, Some("b")) as "struct3B") + val expected = testStructRelation + .select( + GetStructField('struct1, 0, Some("a")) as "struct2A", + Literal(2) as "struct2B", + GetStructField('struct1, 0, Some("a")) as "struct3A", + Literal(3) as "struct3B") + checkRule(query, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index d519fdf378786..655b1d26d6c90 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -61,6 +61,7 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("varchAr(20)", StringType) checkDataType("cHaR(27)", StringType) checkDataType("BINARY", BinaryType) + checkDataType("void", NullType) checkDataType("interval", CalendarIntervalType) checkDataType("array", ArrayType(DoubleType, true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index a721e17aef02d..f037ce7b9e793 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -16,19 +16,11 @@ */ package org.apache.spark.sql.catalyst.parser -import java.io.File -import java.nio.file.Files - -import scala.collection.JavaConverters._ -import scala.collection.mutable - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.util.fileToString +import org.apache.spark.sql.catalyst.{SQLKeywordUtils, TableIdentifier} import org.apache.spark.sql.internal.SQLConf -class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { +class TableIdentifierParserSuite extends SparkFunSuite with SQLKeywordUtils { import CatalystSqlParser._ // Add "$elem$", "$value$" & "$key$" @@ -292,121 +284,6 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { "where", "with") - private val sqlSyntaxDefs = { - val sqlBasePath = { - java.nio.file.Paths.get(sparkHome, "sql", "catalyst", "src", "main", "antlr4", "org", - "apache", "spark", "sql", "catalyst", "parser", "SqlBase.g4").toFile - } - fileToString(sqlBasePath).split("\n") - } - - private def parseAntlrGrammars[T](startTag: String, endTag: String) - (f: PartialFunction[String, Seq[T]]): Set[T] = { - val keywords = new mutable.ArrayBuffer[T] - val default = (_: String) => Nil - var startTagFound = false - var parseFinished = false - val lineIter = sqlSyntaxDefs.toIterator - while (!parseFinished && lineIter.hasNext) { - val line = lineIter.next() - if (line.trim.startsWith(startTag)) { - startTagFound = true - } else if (line.trim.startsWith(endTag)) { - parseFinished = true - } else if (startTagFound) { - f.applyOrElse(line, default).foreach { symbol => - keywords += symbol - } - } - } - assert(keywords.nonEmpty && startTagFound && parseFinished, "cannot extract keywords from " + - s"the `SqlBase.g4` file, so please check if the start/end tags (`$startTag` and `$endTag`) " + - "are placed correctly in the 
file.") - keywords.toSet - } - - // If a symbol does not have the same string with its literal (e.g., `SETMINUS: 'MINUS';`), - // we need to map a symbol to actual literal strings. - val symbolsToExpandIntoDifferentLiterals = { - val kwDef = """([A-Z_]+):(.+);""".r - val keywords = parseAntlrGrammars( - "//--SPARK-KEYWORD-LIST-START", "//--SPARK-KEYWORD-LIST-END") { - case kwDef(symbol, literalDef) => - val splitDefs = literalDef.split("""\|""") - val hasMultipleLiterals = splitDefs.length > 1 - // The case where a symbol has multiple literal definitions, - // e.g., `DATABASES: 'DATABASES' | 'SCHEMAS';`. - if (hasMultipleLiterals) { - // Filters out inappropriate entries, e.g., `!` in `NOT: 'NOT' | '!';` - val litDef = """([A-Z_]+)""".r - val literals = splitDefs.map(_.replaceAll("'", "").trim).toSeq.flatMap { - case litDef(lit) => Some(lit) - case _ => None - } - (symbol, literals) :: Nil - } else { - val literal = literalDef.replaceAll("'", "").trim - // The case where a symbol string and its literal string are different, - // e.g., `SETMINUS: 'MINUS';`. - if (symbol != literal) { - (symbol, literal :: Nil) :: Nil - } else { - Nil - } - } - } - keywords.toMap - } - - // All the SQL keywords defined in `SqlBase.g4` - val allCandidateKeywords = { - val kwDef = """([A-Z_]+):.+;""".r - val keywords = parseAntlrGrammars( - "//--SPARK-KEYWORD-LIST-START", "//--SPARK-KEYWORD-LIST-END") { - // Parses a pattern, e.g., `AFTER: 'AFTER';` - case kwDef(symbol) => - if (symbolsToExpandIntoDifferentLiterals.contains(symbol)) { - symbolsToExpandIntoDifferentLiterals(symbol) - } else { - symbol :: Nil - } - } - keywords - } - - val nonReservedKeywordsInAnsiMode = { - val kwDef = """\s*[\|:]\s*([A-Z_]+)\s*""".r - parseAntlrGrammars("//--ANSI-NON-RESERVED-START", "//--ANSI-NON-RESERVED-END") { - // Parses a pattern, e.g., ` | AFTER` - case kwDef(symbol) => - if (symbolsToExpandIntoDifferentLiterals.contains(symbol)) { - symbolsToExpandIntoDifferentLiterals(symbol) - } else { - symbol :: Nil - } - } - } - - val reservedKeywordsInAnsiMode = allCandidateKeywords -- nonReservedKeywordsInAnsiMode - - test("check # of reserved keywords") { - val numReservedKeywords = 74 - assert(reservedKeywordsInAnsiMode.size == numReservedKeywords, - s"The expected number of reserved keywords is $numReservedKeywords, but " + - s"${reservedKeywordsInAnsiMode.size} found.") - } - - test("reserved keywords in Spark are also reserved in SQL 2016") { - withTempDir { dir => - val tmpFile = new File(dir, "tmp") - val is = Thread.currentThread().getContextClassLoader - .getResourceAsStream("ansi-sql-2016-reserved-keywords.txt") - Files.copy(is, tmpFile.toPath) - val reservedKeywordsInSql2016 = Files.readAllLines(tmpFile.toPath) - .asScala.filterNot(_.startsWith("--")).map(_.trim).toSet - assert((reservedKeywordsInAnsiMode -- reservedKeywordsInSql2016).isEmpty) - } - } test("table identifier") { // Regular names. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala index 229e32479082c..f921f06537080 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.util -import scala.collection._ - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala index 51286986b835c..79c06cf8313b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TimestampFormatterSuite.scala @@ -288,14 +288,14 @@ class TimestampFormatterSuite extends DatetimeFormatterSuite { withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zoneId.getId) { withDefaultTimeZone(zoneId) { withClue(s"zoneId = ${zoneId.getId}") { - val formatters = LegacyDateFormats.values.map { legacyFormat => + val formatters = LegacyDateFormats.values.toSeq.map { legacyFormat => TimestampFormatter( TimestampFormatter.defaultPattern, zoneId, TimestampFormatter.defaultLocale, legacyFormat, isParsing = false) - }.toSeq :+ TimestampFormatter.getFractionFormatter(zoneId) + } :+ TimestampFormatter.getFractionFormatter(zoneId) formatters.foreach { formatter => assert(microsToInstant(formatter.parse("1000-01-01 01:02:03")) .atZone(zoneId) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 3d7026e180cd1..616fc72320caf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector +import java.time.{Instant, ZoneId} +import java.time.temporal.ChronoUnit import java.util import scala.collection.JavaConverters._ @@ -25,12 +27,13 @@ import scala.collection.mutable import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform} +import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -46,10 +49,15 @@ class InMemoryTable( private val allowUnsupportedTransforms = properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean - partitioning.foreach { t => - if (!t.isInstanceOf[IdentityTransform] && 
!allowUnsupportedTransforms) { - throw new IllegalArgumentException(s"Transform $t must be IdentityTransform") - } + partitioning.foreach { + case _: IdentityTransform => + case _: YearsTransform => + case _: MonthsTransform => + case _: DaysTransform => + case _: HoursTransform => + case _: BucketTransform => + case t if !allowUnsupportedTransforms => + throw new IllegalArgumentException(s"Transform $t is not a supported transform") } // The key `Seq[Any]` is the partition values. @@ -66,8 +74,14 @@ class InMemoryTable( } } + private val UTC = ZoneId.of("UTC") + private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate + private def getKey(row: InternalRow): Seq[Any] = { - def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = { + def extractor( + fieldNames: Array[String], + schema: StructType, + row: InternalRow): (Any, DataType) = { val index = schema.fieldIndex(fieldNames(0)) val value = row.toSeq(schema).apply(index) if (fieldNames.length > 1) { @@ -78,10 +92,44 @@ class InMemoryTable( throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}") } } else { - value + (value, schema(index).dataType) } } - partCols.map(fieldNames => extractor(fieldNames, schema, row)) + + partitioning.map { + case IdentityTransform(ref) => + extractor(ref.fieldNames, schema, row)._1 + case YearsTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (days: Int, DateType) => + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) + } + case MonthsTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (days: Int, DateType) => + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate) + } + case DaysTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (days, DateType) => + days + case (micros: Long, TimestampType) => + ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) + } + case HoursTransform(ref) => + extractor(ref.fieldNames, schema, row) match { + case (micros: Long, TimestampType) => + ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) + } + case BucketTransform(numBuckets, ref) => + (extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets + } } def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala index 1a262d646ca10..9fa016146bbd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -502,6 +502,6 @@ abstract class DataTypeWriteCompatibilityBaseSuite extends SparkFunSuite { DataType.canWrite(writeType, readType, byName, analysis.caseSensitiveResolution, name, storeAssignmentPolicy, errMsg => errs += errMsg) === false, desc) assert(errs.size === numErrs, s"Should produce $numErrs error messages") - checkErrors(errs) + 
checkErrors(errs.toSeq) } } diff --git a/sql/core/src/main/scala-2.13/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala-2.13/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 0aa29640899c6..6aa1b46cbb94a 100644 --- a/sql/core/src/main/scala-2.13/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala-2.13/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkD */ class StreamProgress( val baseMap: immutable.Map[SparkDataStream, OffsetV2] = - new immutable.HashMap[SparkDataStream, OffsetV2]) + new immutable.HashMap[SparkDataStream, OffsetV2]) extends scala.collection.immutable.Map[SparkDataStream, OffsetV2] { // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e6f7b1d723af6..da542c67d9c51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -871,6 +871,72 @@ class Column(val expr: Expression) extends Logging { */ def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) } + // scalastyle:off line.size.limit + /** + * An expression that adds/replaces a field in `StructType` by name. + * + * {{{ + * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + * df.select($"struct_col".withField("c", lit(3))) + * // result: {"a":1,"b":2,"c":3} + * + * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + * df.select($"struct_col".withField("b", lit(3))) + * // result: {"a":1,"b":3} + * + * val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col") + * df.select($"struct_col".withField("c", lit(3))) + * // result: null of type struct<a:int,b:int,c:int> + * + * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col") + * df.select($"struct_col".withField("b", lit(100))) + * // result: {"a":1,"b":100,"b":100} + * + * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + * df.select($"struct_col".withField("a.c", lit(3))) + * // result: {"a":{"a":1,"b":2,"c":3}} + * + * val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col") + * df.select($"struct_col".withField("a.c", lit(3))) + * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields + * }}} + * + * @group expr_ops + * @since 3.1.0 + */ + // scalastyle:on line.size.limit + def withField(fieldName: String, col: Column): Column = withExpr { + require(fieldName != null, "fieldName cannot be null") + require(col != null, "col cannot be null") + + val nameParts = if (fieldName.isEmpty) { + fieldName :: Nil + } else { + CatalystSqlParser.parseMultipartIdentifier(fieldName) + } + withFieldHelper(expr, nameParts, Nil, col.expr) + } + + private def withFieldHelper( + struct: Expression, + namePartsRemaining: Seq[String], + namePartsDone: Seq[String], + value: Expression) : WithFields = { + val name = namePartsRemaining.head + if (namePartsRemaining.length == 1) { + WithFields(struct, name :: Nil, value :: Nil) + } else { + val newNamesRemaining = namePartsRemaining.tail + val newNamesDone = namePartsDone :+ name + val newValue = withFieldHelper( + struct = UnresolvedExtractValue(struct, Literal(name)), + namePartsRemaining = newNamesRemaining, + namePartsDone = newNamesDone,
+ value = value) + WithFields(struct, name :: Nil, newValue :: Nil) + } + } + /** * An expression that gets a field by name in a `StructType`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6f97121d88ede..d5501326397c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2030,7 +2030,47 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.3.0 */ - def unionByName(other: Dataset[T]): Dataset[T] = withSetOperator { + def unionByName(other: Dataset[T]): Dataset[T] = unionByName(other, false) + + /** + * Returns a new Dataset containing union of rows in this Dataset and another Dataset. + * + * The difference between this function and [[union]] is that this function + * resolves columns by name (not by position). + * + * When the parameter `allowMissingColumns` is true, this function allows different set + * of column names between two Datasets. Missing columns at each side, will be filled with + * null values. The missing columns at left Dataset will be added at the end in the schema + * of the union result: + * + * {{{ + * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") + * val df2 = Seq((4, 5, 6)).toDF("col1", "col0", "col3") + * df1.unionByName(df2, true).show + * + * // output: "col3" is missing at left df1 and added at the end of schema. + * // +----+----+----+----+ + * // |col0|col1|col2|col3| + * // +----+----+----+----+ + * // | 1| 2| 3|null| + * // | 5| 4|null| 6| + * // +----+----+----+----+ + * + * df2.unionByName(df1, true).show + * + * // output: "col2" is missing at left df2 and added at the end of schema. + * // +----+----+----+----+ + * // |col1|col0|col3|col2| + * // +----+----+----+----+ + * // | 4| 5| 6|null| + * // | 2| 1|null| 3| + * // +----+----+----+----+ + * }}} + * + * @group typedrel + * @since 3.1.0 + */ + def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = withSetOperator { // Check column name duplication val resolver = sparkSession.sessionState.analyzer.resolver val leftOutputAttrs = logicalPlan.output @@ -2048,9 +2088,13 @@ class Dataset[T] private[sql]( // Builds a project list for `other` based on `logicalPlan` output names val rightProjectList = leftOutputAttrs.map { lattr => rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse { - throw new AnalysisException( - s"""Cannot resolve column name "${lattr.name}" among """ + - s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""") + if (allowMissingColumns) { + Alias(Literal(null, lattr.dataType), lattr.name)() + } else { + throw new AnalysisException( + s"""Cannot resolve column name "${lattr.name}" among """ + + s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""") + } } } @@ -2058,9 +2102,20 @@ class Dataset[T] private[sql]( val notFoundAttrs = rightOutputAttrs.diff(rightProjectList) val rightChild = Project(rightProjectList ++ notFoundAttrs, other.logicalPlan) + // Builds a project for `logicalPlan` based on `other` output names, if allowing + // missing columns. + val leftChild = if (allowMissingColumns) { + val missingAttrs = notFoundAttrs.map { attr => + Alias(Literal(null, attr.dataType), attr.name)() + } + Project(leftOutputAttrs ++ missingAttrs, logicalPlan) + } else { + logicalPlan + } + // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. 
- CombineUnions(Union(logicalPlan, rightChild)) + CombineUnions(Union(leftChild, rightChild)) } /** @@ -3532,7 +3587,7 @@ class Dataset[T] private[sql]( val numPartitions = arrowBatchRdd.partitions.length // Store collection results for worst case of 1 to N-1 partitions - val results = new Array[Array[Array[Byte]]](numPartitions - 1) + val results = new Array[Array[Array[Byte]]](Math.max(0, numPartitions - 1)) var lastIndex = -1 // index of last partition written // Handler to eagerly write partitions to Python in order diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index c37d8eaa294bf..611c03e7b208e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -479,7 +479,7 @@ class RelationalGroupedDataset protected[sql]( * @since 2.4.0 */ def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = { - pivot(pivotColumn, values.asScala) + pivot(pivotColumn, values.asScala.toSeq) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index ea1a9f12cd24b..08b0a1c6a60a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -372,7 +372,7 @@ class SparkSession private( */ @DeveloperApi def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = withActive { - Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala.toSeq)) } /** @@ -495,7 +495,7 @@ class SparkSession private( * @since 2.0.0 */ def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { - createDataset(data.asScala) + createDataset(data.asScala.toSeq) } /** @@ -1087,7 +1087,7 @@ object SparkSession extends Logging { } private def assertOnDriver(): Unit = { - if (Utils.isTesting && TaskContext.get != null) { + if (TaskContext.get != null) { // we're accessing it during task execution, fail. throw new IllegalStateException( "SparkSession should only be created and accessed on the driver.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index 1c2bf9e7c2a57..ff706b5061f0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -103,7 +103,7 @@ class SparkSessionExtensions { * Build the override rules for columnar execution. */ private[sql] def buildColumnarRules(session: SparkSession): Seq[ColumnarRule] = { - columnarRuleBuilders.map(_.apply(session)) + columnarRuleBuilders.map(_.apply(session)).toSeq } /** @@ -119,7 +119,7 @@ class SparkSessionExtensions { * Build the analyzer resolution `Rule`s using the given [[SparkSession]]. */ private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { - resolutionRuleBuilders.map(_.apply(session)) + resolutionRuleBuilders.map(_.apply(session)).toSeq } /** @@ -136,7 +136,7 @@ class SparkSessionExtensions { * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]]. 
*/ private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { - postHocResolutionRuleBuilders.map(_.apply(session)) + postHocResolutionRuleBuilders.map(_.apply(session)).toSeq } /** @@ -153,7 +153,7 @@ class SparkSessionExtensions { * Build the check analysis `Rule`s using the given [[SparkSession]]. */ private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = { - checkRuleBuilders.map(_.apply(session)) + checkRuleBuilders.map(_.apply(session)).toSeq } /** @@ -168,7 +168,7 @@ class SparkSessionExtensions { private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder] private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { - optimizerRules.map(_.apply(session)) + optimizerRules.map(_.apply(session)).toSeq } /** @@ -184,7 +184,7 @@ class SparkSessionExtensions { private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = { - plannerStrategyBuilders.map(_.apply(session)) + plannerStrategyBuilders.map(_.apply(session)).toSeq } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index ced4af46c3f30..0f6ae9c5d44e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -133,9 +133,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | * @since 1.3.0 | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) | val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = $inputEncoders - | val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + | val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) | val finalUdf = if (nullable) udf else udf.asNonNullable() | def builder(e: Seq[Expression]) = if (e.length == $x) { | finalUdf.createScalaUDF(e) @@ -179,9 +180,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 0) { finalUdf.createScalaUDF(e) @@ -199,9 +201,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val 
outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 1) { finalUdf.createScalaUDF(e) @@ -219,9 +222,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 2) { finalUdf.createScalaUDF(e) @@ -239,9 +243,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 3) { finalUdf.createScalaUDF(e) @@ -259,9 +264,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, 
outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 4) { finalUdf.createScalaUDF(e) @@ -279,9 +285,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 5) { finalUdf.createScalaUDF(e) @@ -299,9 +306,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 6) { finalUdf.createScalaUDF(e) @@ -319,9 +327,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + 
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 7) { finalUdf.createScalaUDF(e) @@ -339,9 +348,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 8) { finalUdf.createScalaUDF(e) @@ -359,9 +369,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 9) { finalUdf.createScalaUDF(e) @@ -379,9 +390,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = 
outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 10) { finalUdf.createScalaUDF(e) @@ -399,9 +411,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 11) { finalUdf.createScalaUDF(e) @@ -419,9 +432,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: 
Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 12) { finalUdf.createScalaUDF(e) @@ -439,9 +453,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 13) { finalUdf.createScalaUDF(e) @@ -459,9 +474,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, 
outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 14) { finalUdf.createScalaUDF(e) @@ -479,9 +495,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 15) { finalUdf.createScalaUDF(e) @@ -499,9 +516,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, 
inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 16) { finalUdf.createScalaUDF(e) @@ -519,9 +537,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 17) { finalUdf.createScalaUDF(e) @@ -539,9 +558,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: 
Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 18) { finalUdf.createScalaUDF(e) @@ -559,9 +579,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 19) { finalUdf.createScalaUDF(e) @@ -579,9 +600,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: 
Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 20) { finalUdf.createScalaUDF(e) @@ -599,9 +621,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 21) { finalUdf.createScalaUDF(e) @@ -619,9 +642,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 1.3.0 */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, 
A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + val outputEncoder = Try(ExpressionEncoder[RT]()).toOption + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Try(ExpressionEncoder[A22]()).toOption :: Nil - val udf = SparkUserDefinedFunction(func, dataType, inputEncoders).withName(name) + val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 22) { finalUdf.createScalaUDF(e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index bf90875e511f8..bc3f38a35834d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -48,6 +48,7 @@ class ResolveSessionCatalog( override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case AlterTableAddColumnsStatement( nameParts @ SessionCatalogAndTable(catalog, tbl), cols) => + cols.foreach(c => failNullType(c.dataType)) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => if (!DDLUtils.isHiveTable(v1Table.v1Table)) { @@ -76,6 +77,7 @@ class ResolveSessionCatalog( case AlterTableReplaceColumnsStatement( nameParts @ SessionCatalogAndTable(catalog, tbl), cols) => + cols.foreach(c => failNullType(c.dataType)) val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match { case Some(_: V1Table) => throw new AnalysisException("REPLACE COLUMNS is only supported with v2 tables.") @@ -100,6 +102,7 @@ class ResolveSessionCatalog( case a @ AlterTableAlterColumnStatement( nameParts @ SessionCatalogAndTable(catalog, tbl), _, _, _, _, _) => + a.dataType.foreach(failNullType) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => if (!DDLUtils.isHiveTable(v1Table.v1Table)) { @@ -268,6 +271,7 @@ class ResolveSessionCatalog( // session catalog and the table provider is not v2. 
case c @ CreateTableStatement( SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + assertNoNullTypeInSchema(c.tableSchema) val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { if (!DDLUtils.isHiveTable(Some(provider))) { @@ -292,6 +296,9 @@ class ResolveSessionCatalog( case c @ CreateTableAsSelectStatement( SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + if (c.asSelect.resolved) { + assertNoNullTypeInSchema(c.asSelect.schema) + } val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { val tableDesc = buildCatalogTable(tbl.asTableIdentifier, new StructType, @@ -319,6 +326,7 @@ class ResolveSessionCatalog( // session catalog and the table provider is not v2. case c @ ReplaceTableStatement( SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + assertNoNullTypeInSchema(c.tableSchema) val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { throw new AnalysisException("REPLACE TABLE is only supported with v2 tables.") @@ -336,6 +344,9 @@ class ResolveSessionCatalog( case c @ ReplaceTableAsSelectStatement( SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + if (c.asSelect.resolved) { + assertNoNullTypeInSchema(c.asSelect.schema) + } val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { throw new AnalysisException("REPLACE TABLE AS SELECT is only supported with v2 tables.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala index 9807b5dbe9348..94e159c562e31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala @@ -257,16 +257,16 @@ object AggregatingAccumulator { imperative }) - val updateAttrSeq: AttributeSeq = aggBufferAttributes ++ inputAttributes - val mergeAttrSeq: AttributeSeq = aggBufferAttributes ++ inputAggBufferAttributes - val aggBufferAttributesSeq: AttributeSeq = aggBufferAttributes + val updateAttrSeq: AttributeSeq = (aggBufferAttributes ++ inputAttributes).toSeq + val mergeAttrSeq: AttributeSeq = (aggBufferAttributes ++ inputAggBufferAttributes).toSeq + val aggBufferAttributesSeq: AttributeSeq = aggBufferAttributes.toSeq // Create the accumulator. 
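// Illustrative sketch (not part of this patch): the failNullType / assertNoNullTypeInSchema
// calls added in ResolveSessionCatalog above reject DDL that would introduce a column of
// NullType. A minimal stand-in for that kind of check, assuming it simply walks the schema
// recursively and raises AnalysisException on the first NullType it finds; the real helper
// lives elsewhere in the catalog utilities.
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types._

def assertNoNullTypeInSchemaSketch(schema: StructType): Unit = {
  def check(dt: DataType): Unit = dt match {
    case NullType => throw new AnalysisException("Cannot create a table with a column of unknown (null) type.")
    case s: StructType => s.fields.foreach(f => check(f.dataType))
    case a: ArrayType => check(a.elementType)
    case m: MapType => check(m.keyType); check(m.valueType)
    case _ => // concrete types are fine
  }
  check(schema)
}

// e.g. a CTAS of `SELECT null AS c` would produce a schema like the one below and fail:
// assertNoNullTypeInSchemaSketch(StructType(Seq(StructField("c", NullType))))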
new AggregatingAccumulator( - aggBufferAttributes.map(_.dataType), - initialValues, - updateExpressions.map(BindReferences.bindReference(_, updateAttrSeq)), - mergeExpressions.map(BindReferences.bindReference(_, mergeAttrSeq)), + aggBufferAttributes.map(_.dataType).toSeq, + initialValues.toSeq, + updateExpressions.map(BindReferences.bindReference(_, updateAttrSeq)).toSeq, + mergeExpressions.map(BindReferences.bindReference(_, mergeAttrSeq)).toSeq, resultExpressions.map(BindReferences.bindReference(_, aggBufferAttributesSeq)), imperatives.toArray, typedImperatives.toArray, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala new file mode 100644 index 0000000000000..22bf6df58b040 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io.OutputStream +import java.nio.charset.StandardCharsets +import java.util.concurrent.TimeUnit + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils} + +trait BaseScriptTransformationExec extends UnaryExecNode { + + override def producedAttributes: AttributeSet = outputSet -- inputSet + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def doExecute(): RDD[InternalRow] = { + val broadcastedHadoopConf = + new SerializableConfiguration(sqlContext.sessionState.newHadoopConf()) + + child.execute().mapPartitions { iter => + if (iter.hasNext) { + val proj = UnsafeProjection.create(schema) + processIterator(iter, broadcastedHadoopConf.value).map(proj) + } else { + // If the input iterator has no rows then do not launch the external script. 
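// Illustrative sketch (not part of this patch): the mapPartitions guard in doExecute above
// avoids forking the external transformation script for empty partitions. The same pattern
// in isolation, with a hypothetical `launch` standing in for the subprocess setup.
def transformPartition[T, U](iter: Iterator[T])(launch: Iterator[T] => Iterator[U]): Iterator[U] =
  if (iter.hasNext) launch(iter)   // only pay the subprocess startup cost when there is input
  else Iterator.empty              // empty partition: nothing to feed the script, skip it entirely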
+ Iterator.empty + } + } + } + + def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration): Iterator[InternalRow] + + protected def checkFailureAndPropagate( + writerThread: BaseScriptTransformationWriterThread, + cause: Throwable = null, + proc: Process, + stderrBuffer: CircularBuffer): Unit = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + + // There can be a lag between reader read EOF and the process termination. + // If the script fails to startup, this kind of error may be missed. + // So explicitly waiting for the process termination. + val timeout = conf.getConf(SQLConf.SCRIPT_TRANSFORMATION_EXIT_TIMEOUT) + val exitRes = proc.waitFor(timeout, TimeUnit.SECONDS) + if (!exitRes) { + log.warn(s"Transformation script process exits timeout in $timeout seconds") + } + + if (!proc.isAlive) { + val exitCode = proc.exitValue() + if (exitCode != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer + throw new SparkException(s"Subprocess exited with status $exitCode. " + + s"Error: ${stderrBuffer.toString}", cause) + } + } + } +} + +abstract class BaseScriptTransformationWriterThread( + iter: Iterator[InternalRow], + inputSchema: Seq[DataType], + ioSchema: BaseScriptTransformIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext, + conf: Configuration) extends Thread with Logging { + + setDaemon(true) + + @volatile protected var _exception: Throwable = null + + /** Contains the exception thrown while writing the parent iterator to the external process. */ + def exception: Option[Throwable] = Option(_exception) + + protected def processRows(): Unit + + protected def processRowsWithoutSerde(): Unit = { + val len = inputSchema.length + iter.foreach { row => + val data = if (len == 0) { + ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") + } else { + val sb = new StringBuilder + sb.append(row.get(0, inputSchema(0))) + var i = 1 + while (i < len) { + sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + sb.append(row.get(i, inputSchema(i))) + i += 1 + } + sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) + sb.toString() + } + outputStream.write(data.getBytes(StandardCharsets.UTF_8)) + } + } + + override def run(): Unit = Utils.logUncaughtExceptions { + TaskContext.setTaskContext(taskContext) + + // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so + // let's use a variable to record whether the `finally` block was hit due to an exception + var threwException: Boolean = true + try { + processRows() + threwException = false + } catch { + // SPARK-25158 Exception should not be thrown again, otherwise it will be captured by + // SparkUncaughtExceptionHandler, then Executor will exit because of this Uncaught Exception, + // so pass the exception to `ScriptTransformationExec` is enough. + case t: Throwable => + // An error occurred while writing input, so kill the child process. 
According to the + // Javadoc this call will not throw an exception: + _exception = t + proc.destroy() + logError("Thread-ScriptTransformation-Feed exit cause by: ", t) + } finally { + try { + Utils.tryLogNonFatalError(outputStream.close()) + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer + } + } catch { + case NonFatal(exceptionFromFinallyBlock) => + if (!threwException) { + throw exceptionFromFinallyBlock + } else { + log.error("Exception in finally block", exceptionFromFinallyBlock) + } + } + } + } +} + +/** + * The wrapper class of input and output schema properties + */ +abstract class BaseScriptTransformIOSchema extends Serializable { + import ScriptIOSchema._ + + def inputRowFormat: Seq[(String, String)] + + def outputRowFormat: Seq[(String, String)] + + def inputSerdeClass: Option[String] + + def outputSerdeClass: Option[String] + + def inputSerdeProps: Seq[(String, String)] + + def outputSerdeProps: Seq[(String, String)] + + def recordReaderClass: Option[String] + + def recordWriterClass: Option[String] + + def schemaLess: Boolean + + val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) + val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) +} + +object ScriptIOSchema { + val defaultFormat = Map( + ("TOK_TABLEROWFORMATFIELD", "\t"), + ("TOK_TABLEROWFORMATLINES", "\n") + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3a2c673229c20..363282ea95997 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -229,14 +229,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { */ override def visitNestedConstantList( ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) { - ctx.constantList.asScala.map(visitConstantList) + ctx.constantList.asScala.map(visitConstantList).toSeq } /** * Convert a constants list into a String sequence. */ override def visitConstantList(ctx: ConstantListContext): Seq[String] = withOrigin(ctx) { - ctx.constant.asScala.map(visitStringConstant) + ctx.constant.asScala.map(visitStringConstant).toSeq } /** @@ -355,7 +355,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Storage format val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + validateRowFormatFileFormat( + ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx) val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3f339347ab4db..7b5d8f15962d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -159,7 +159,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // 4. Pick cartesian product if join type is inner like. // 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have // other choice. 
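// Illustrative sketch (not part of this patch): when no SerDe is configured,
// processRowsWithoutSerde above writes each input row as delimiter-joined text using the
// TOK_TABLEROWFORMATFIELD / TOK_TABLEROWFORMATLINES entries, which default to "\t" and "\n".
// A standalone rendering of that formatting logic over plain Scala values:
val defaultFormat = Map(
  "TOK_TABLEROWFORMATFIELD" -> "\t",
  "TOK_TABLEROWFORMATLINES" -> "\n")

def renderRow(row: Seq[Any], rowFormat: Map[String, String] = defaultFormat): String =
  if (row.isEmpty) {
    rowFormat("TOK_TABLEROWFORMATLINES")  // schema-less input: just a line terminator
  } else {
    row.mkString(rowFormat("TOK_TABLEROWFORMATFIELD")) + rowFormat("TOK_TABLEROWFORMATLINES")
  }

// renderRow(Seq(1, "a", 2.5)) == "1\ta\t2.5\n" -- this is what the script receives on stdin.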
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) => + case j @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond, left, right, hint) => def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = { getBroadcastBuildSide(left, right, joinType, hint, onlyLookingAtHint, conf).map { buildSide => @@ -168,7 +168,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { rightKeys, joinType, buildSide, - condition, + nonEquiCond, planLater(left), planLater(right))) } @@ -182,7 +182,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { rightKeys, joinType, buildSide, - condition, + nonEquiCond, planLater(left), planLater(right))) } @@ -191,7 +191,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def createSortMergeJoin() = { if (RowOrdering.isOrderable(leftKeys)) { Some(Seq(joins.SortMergeJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)))) + leftKeys, rightKeys, joinType, nonEquiCond, planLater(left), planLater(right)))) } else { None } @@ -199,7 +199,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def createCartesianProduct() = { if (joinType.isInstanceOf[InnerLike]) { - Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition))) + // `CartesianProductExec` can't implicitly evaluate equal join condition, here we should + // pass the original condition which includes both equal and non-equal conditions. + Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), j.condition))) } else { None } @@ -220,7 +222,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This join could be very slow or OOM val buildSide = getSmallerSide(left, right) Seq(joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition)) + planLater(left), planLater(right), buildSide, joinType, nonEquiCond)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 0244542054611..558d990e8c4bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -263,7 +263,7 @@ trait CodegenSupport extends SparkPlan { paramVars += ExprCode(paramIsNull, JavaCode.variable(paramName, attributes(i).dataType)) } - (arguments, parameters, paramVars) + (arguments.toSeq, parameters.toSeq, paramVars.toSeq) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index bc924e6978ddc..112090640040a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -196,7 +196,7 @@ case class AdaptiveSparkPlanExec( // In case of errors, we cancel all running stages and throw exception. if (errors.nonEmpty) { - cleanUpAndThrowException(errors, None) + cleanUpAndThrowException(errors.toSeq, None) } // Try re-optimizing and re-planning. 
Adopt the new plan if its cost is equal to or less diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala index 3cf6a13a4a892..8d7a2c95081c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala @@ -65,7 +65,7 @@ trait AdaptiveSparkPlanHelper { def mapPlans[A](p: SparkPlan)(f: SparkPlan => A): Seq[A] = { val ret = new collection.mutable.ArrayBuffer[A]() foreach(p)(ret += f(_)) - ret + ret.toSeq } /** @@ -75,7 +75,7 @@ trait AdaptiveSparkPlanHelper { def flatMap[A](p: SparkPlan)(f: SparkPlan => TraversableOnce[A]): Seq[A] = { val ret = new collection.mutable.ArrayBuffer[A]() foreach(p)(ret ++= f(_)) - ret + ret.toSeq } /** @@ -86,7 +86,7 @@ trait AdaptiveSparkPlanHelper { val ret = new collection.mutable.ArrayBuffer[B]() val lifted = pf.lift foreach(p)(node => lifted(node).foreach(ret.+=)) - ret + ret.toSeq } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala index 17be37b1fb27e..b3b38b5e8ee85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala @@ -154,7 +154,7 @@ case class CustomShuffleReaderExec private( partitionDataSizeMetrics.set(dataSizes.sum) } - SQLMetrics.postDriverMetricsUpdatedByValue(sparkContext, executionId, driverAccumUpdates) + SQLMetrics.postDriverMetricsUpdatedByValue(sparkContext, executionId, driverAccumUpdates.toSeq) } @transient override lazy val metrics: Map[String, SQLMetric] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index b6bb48ae9cc38..6f7b7aca2029c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -291,8 +291,8 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight") if (numSkewedLeft > 0 || numSkewedRight > 0) { - val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions) - val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions) + val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions.toSeq) + val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions.toSeq) val newSmj = replaceSkewedShufleReader( replaceSkewedShufleReader(smj, newLeft), newRight).asInstanceOf[SortMergeJoinExec] newSmj.copy(isSkewJoin = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala index d6e44b780d772..83fdafbadcb60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -121,7 +121,7 @@ object 
ShufflePartitionsUtil extends Logging { i += 1 } createPartitionSpec() - partitionSpecs + partitionSpecs.toSeq } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 544b90a736071..44bc9c2e3a9d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ @@ -458,7 +460,8 @@ case class ScalaUDAF( case class ScalaAggregator[IN, BUF, OUT]( children: Seq[Expression], agg: Aggregator[IN, BUF, OUT], - inputEncoderNR: ExpressionEncoder[IN], + inputEncoder: ExpressionEncoder[IN], + bufferEncoder: ExpressionEncoder[BUF], nullable: Boolean = true, isDeterministic: Boolean = true, mutableAggBufferOffset: Int = 0, @@ -469,9 +472,8 @@ case class ScalaAggregator[IN, BUF, OUT]( with ImplicitCastInputTypes with Logging { - private[this] lazy val inputDeserializer = inputEncoderNR.resolveAndBind().createDeserializer() - private[this] lazy val bufferEncoder = - agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind() + // input and buffer encoders are resolved by ResolveEncodersInScalaAgg + private[this] lazy val inputDeserializer = inputEncoder.createDeserializer() private[this] lazy val bufferSerializer = bufferEncoder.createSerializer() private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer() private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]] @@ -479,7 +481,7 @@ case class ScalaAggregator[IN, BUF, OUT]( def dataType: DataType = outputEncoder.objSerializer.dataType - def inputTypes: Seq[DataType] = inputEncoderNR.schema.map(_.dataType) + def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType) override lazy val deterministic: Boolean = isDeterministic @@ -517,3 +519,18 @@ case class ScalaAggregator[IN, BUF, OUT]( override def nodeName: String = agg.getClass.getSimpleName } + +/** + * An extension rule to resolve encoder expressions from a [[ScalaAggregator]] + */ +object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p if !p.resolved => p + case p => p.transformExpressionsUp { + case agg: ScalaAggregator[_, _, _] => + agg.copy( + inputEncoder = agg.inputEncoder.resolveAndBind(), + bufferEncoder = agg.bufferEncoder.resolveAndBind()) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index c8fa07941af87..cf9f3ddeb42a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -123,7 +123,7 @@ case class CachedRDDBuilder( rowCountStats.add(rowCount) val stats = InternalRow.fromSeq( - columnBuilders.flatMap(_.columnStats.collectedStatistics)) + columnBuilders.flatMap(_.columnStats.collectedStatistics).toSeq) CachedBatch(rowCount, columnBuilders.map { builder => JavaUtils.bufferToArray(builder.build()) }, stats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index 33b29bde93ee5..fc62dce5002b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -69,7 +69,7 @@ case class AnalyzePartitionCommand( if (filteredSpec.isEmpty) { None } else { - Some(filteredSpec) + Some(filteredSpec.toMap) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 47b213fc2d83b..d550fe270c753 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -650,7 +650,7 @@ case class AlterTableRecoverPartitionsCommand( val pathFilter = getPathFilter(hadoopConf) val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) - val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] = + val partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)] = try { scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq @@ -697,7 +697,7 @@ case class AlterTableRecoverPartitionsCommand( // parallelize the list of partitions here, then we can have better parallelism later. 
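// Illustrative sketch (not part of this patch): the pervasive .toSeq / .toMap additions in
// these hunks prepare the code for Scala 2.13, where scala.Seq aliases
// scala.collection.immutable.Seq (so a mutable ArrayBuffer no longer conforms) and
// Map#mapValues returns a lazy MapView rather than a Map. A minimal demonstration:
import scala.collection.mutable.ArrayBuffer

def collectSquares(n: Int): Seq[Int] = {
  val buf = new ArrayBuffer[Int]()
  (1 to n).foreach(i => buf += i * i)
  buf.toSeq   // on 2.13 an ArrayBuffer is not an immutable Seq, so convert explicitly
}

def doubleValues(m: Map[String, Int]): Map[String, Int] =
  m.mapValues(_ * 2).toMap   // on 2.13 mapValues is a view; .toMap materializes it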
val parArray = new ParVector(statuses.toVector) parArray.tasksupport = evalTaskSupport - parArray + parArray.seq } else { statuses } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index fc8cc11bb1067..7aebdddf1d59c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -657,7 +657,7 @@ case class DescribeTableCommand( } } - result + result.toSeq } private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { @@ -740,7 +740,7 @@ case class DescribeQueryCommand(queryText: String, plan: LogicalPlan) val result = new ArrayBuffer[Row] val queryExecution = sparkSession.sessionState.executePlan(plan) describeSchema(queryExecution.analyzed.schema, result, header = false) - result + result.toSeq } } @@ -815,7 +815,7 @@ case class DescribeColumnCommand( } yield histogramDescription(hist) buffer ++= histDesc.getOrElse(Seq(Row("histogram", "NULL"))) } - buffer + buffer.toSeq } private def histogramDescription(histogram: Histogram): Seq[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 07d7c4e97a095..db564485be883 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -811,7 +811,7 @@ object DataSource extends Logging { val path = CaseInsensitiveMap(options).get("path") val optionsWithoutPath = options.filterKeys(_.toLowerCase(Locale.ROOT) != "path") CatalogStorageFormat.empty.copy( - locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) + locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath.toMap) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala index 095940772ae78..864130bbd87b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala @@ -80,7 +80,7 @@ object FilePartition extends Logging { currentFiles += file } closePartition() - partitions + partitions.toSeq } def maxSplitBytes( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 57082b40e1132..b5e276bd421a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -46,7 +46,7 @@ class HadoopFileLinesReader( def this(file: PartitionedFile, conf: Configuration) = this(file, None, conf) - private val iterator = { + private val _iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), file.start, @@ -66,9 +66,9 @@ class HadoopFileLinesReader( new RecordReaderIterator(reader) } - override def hasNext: Boolean = iterator.hasNext + override def hasNext: Boolean = _iterator.hasNext - override def next(): Text = iterator.next() + override def next(): 
Text = _iterator.next() - override def close(): Unit = iterator.close() + override def close(): Unit = _iterator.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala index 0e6d803f02d4d..a48001f04a9bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala @@ -35,7 +35,7 @@ import org.apache.spark.input.WholeTextFileRecordReader */ class HadoopFileWholeTextReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { - private val iterator = { + private val _iterator = { val fileSplit = new CombineFileSplit( Array(new Path(new URI(file.filePath))), Array(file.start), @@ -50,9 +50,9 @@ class HadoopFileWholeTextReader(file: PartitionedFile, conf: Configuration) new RecordReaderIterator(reader) } - override def hasNext: Boolean = iterator.hasNext + override def hasNext: Boolean = _iterator.hasNext - override def next(): Text = iterator.next() + override def next(): Text = _iterator.next() - override def close(): Unit = iterator.close() + override def close(): Unit = _iterator.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 84160f35540df..a488ed16a835a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -133,7 +133,7 @@ class InMemoryFileIndex( } val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass)) val discovered = InMemoryFileIndex.bulkListLeafFiles( - pathsToFetch, hadoopConf, filter, sparkSession, areRootPaths = true) + pathsToFetch.toSeq, hadoopConf, filter, sparkSession, areRootPaths = true) discovered.foreach { case (path, leafFiles) => HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size) fileStatusCache.putLeafFiles(path, leafFiles.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 5846d46e146ed..4087efc486a4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -273,7 +273,7 @@ object PartitioningUtils { (None, Some(path)) } else { val (columnNames, values) = columns.reverse.unzip - (Some(PartitionValues(columnNames, values)), Some(currentPath)) + (Some(PartitionValues(columnNames.toSeq, values.toSeq)), Some(currentPath)) } } @@ -420,7 +420,7 @@ object PartitioningUtils { val distinctPartColNames = pathWithPartitionValues.map(_._2.columnNames).distinct def groupByKey[K, V](seq: Seq[(K, V)]): Map[K, Iterable[V]] = - seq.groupBy { case (key, _) => key }.mapValues(_.map { case (_, value) => value }) + seq.groupBy { case (key, _) => key }.mapValues(_.map { case (_, value) => value }).toMap val partColNamesToPaths = groupByKey(pathWithPartitionValues.map { case (path, partValues) => partValues.columnNames -> path diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 375cec597166c..cdac9d9c93925 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -150,21 +150,23 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], options: CSVOptions): Dataset[String] = { val paths = inputPaths.map(_.getPath.toString) + val df = sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName, + options = options.parameters + ).resolveRelation(checkFilesExist = false)) + .select("value").as[String](Encoders.STRING) + if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - className = classOf[TextFileFormat].getName, - options = options.parameters - ).resolveRelation(checkFilesExist = false)) - .select("value").as[String](Encoders.STRING) + df } else { val charset = options.charset - val rdd = sparkSession.sparkContext - .hadoopFile[LongWritable, Text, TextInputFormat](paths.mkString(",")) - .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) - sparkSession.createDataset(rdd)(Encoders.STRING) + sparkSession.createDataset(df.queryExecution.toRdd.map { row => + val bytes = row.getBinary(0) + new String(bytes, 0, bytes.length, charset) + })(Encoders.STRING) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala index dfd84e344eb2a..719d72f5b9b52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala @@ -39,7 +39,7 @@ class JsonOutputWriter( case None => StandardCharsets.UTF_8 } - if (JSONOptionsInRead.blacklist.contains(encoding)) { + if (JSONOptionsInRead.denyList.contains(encoding)) { logWarning(s"The JSON file ($path) was written in the encoding ${encoding.displayName()}" + " which can be read back by Spark only if multiLine is enabled.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 39bbc60200b86..73910c3943e9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -68,14 +68,14 @@ class ParquetFilters( // When g is a `Map`, `g.getOriginalType` is `MAP`. // When g is a `List`, `g.getOriginalType` is `LIST`. case g: GroupType if g.getOriginalType == null => - getPrimitiveFields(g.getFields.asScala, parentFieldNames :+ g.getName) + getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName) // Parquet only supports push-down for primitive types; as a result, Map and List types // are removed. 
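// Illustrative sketch (not part of this patch): TextInputCSVDataSource.createBaseDataset
// above now always reads through the text file source and, for non-UTF-8 charsets, decodes
// the raw bytes from the scan's binary column instead of going through hadoopFile. From a
// user's point of view this path is exercised via the CSV `encoding` option; the file path
// below is hypothetical and assumes a local session.
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[1]").appName("csv-encoding").getOrCreate()

val df = spark.read
  .option("encoding", "ISO-8859-1")   // schema inference also goes through the text source now
  .option("header", "true")
  .csv("/path/to/latin1-encoded.csv") // hypothetical path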
case _ => None } } - val primitiveFields = getPrimitiveFields(schema.getFields.asScala).map { field => + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper (field.fieldNames.toSeq.quoted, field) } @@ -90,7 +90,7 @@ class ParquetFilters( .groupBy(_._1.toLowerCase(Locale.ROOT)) .filter(_._2.size == 1) .mapValues(_.head._2) - CaseInsensitiveMap(dedupPrimitiveFields) + CaseInsensitiveMap(dedupPrimitiveFields.toMap) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 8ce8a86d2f026..2eb205db8ccdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -79,7 +79,7 @@ class ParquetToSparkSchemaConverter( } } - StructType(fields) + StructType(fields.toSeq) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 95343e2872def..60cacda9f5f1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogV2Util.assertNoNullTypeInSchema import org.apache.spark.sql.connector.expressions.{FieldReference, RewritableTransform} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 @@ -292,6 +293,8 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi "in the table definition of " + table.identifier, sparkSession.sessionState.conf.caseSensitiveAnalysis) + assertNoNullTypeInSchema(schema) + val normalizedPartCols = normalizePartitionColumns(schema, table) val normalizedBucketSpec = normalizeBucketSpec(schema, table) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index e4e7887017a1d..c199df676ced3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -40,7 +40,7 @@ case class BatchScanExec( override def hashCode(): Int = batch.hashCode() - override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions() + @transient override lazy val partitions: Seq[InputPartition] = batch.planInputPartitions() override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index 211f61279ddd5..083c6bc7999bc 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan, SupportsReportPartitioning} -import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils @@ -43,7 +44,32 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { override def simpleString(maxFields: Int): String = { val result = s"$nodeName${truncatedString(output, "[", ", ", "]", maxFields)} ${scan.description()}" - Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, result) + redact(result) + } + + /** + * Shorthand for calling redact() without specifying redacting rules + */ + protected def redact(text: String): String = { + Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text) + } + + override def verboseStringWithOperatorId(): String = { + val metaDataStr = scan match { + case s: SupportsMetadata => + s.getMetaData().toSeq.sorted.flatMap { + case (_, value) if value.isEmpty || value.equals("[]") => None + case (key, value) => Some(s"$key: ${redact(value)}") + case _ => None + } + case _ => + Seq(scan.description()) + } + s""" + |$formattedNodeName + |${ExplainUtils.generateFieldString("Output", output)} + |${metaDataStr.mkString("\n")} + |""".stripMargin } override def outputPartitioning: physical.Partitioning = scan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index cca80c0cb6d57..f289a867e5ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, StagingTableCatalo import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -218,18 +218,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil - case Repartition(1, false, child) => - val isContinuous = child.find { - case r: StreamingDataSourceV2Relation => r.stream.isInstanceOf[ContinuousStream] - case _ => false - 
}.isDefined - - if (isContinuous) { - ContinuousCoalesceExec(1, planLater(child)) :: Nil - } else { - Nil - } - case desc @ DescribeNamespace(ResolvedNamespace(catalog, ns), extended) => DescribeNamespaceExec(desc.output, catalog.asNamespaceCatalog, ns, extended) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala index b4a14c6face31..e273abf90e3bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala @@ -55,7 +55,7 @@ case class DescribeNamespaceExec( rows += toCatalystRow("Properties", properties.toSeq.mkString("(", ",", ")")) } } - rows + rows.toSeq } private def toCatalystRow(strs: String*): InternalRow = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala index bc6bb175f979e..81b1c81499c74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala @@ -43,7 +43,7 @@ case class DescribeTableExec( if (isExtended) { addTableDetails(rows) } - rows + rows.toSeq } private def addTableDetails(rows: ArrayBuffer[InternalRow]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 7e8e0ed2dc675..f090d7861b629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -29,11 +29,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics} import org.apache.spark.sql.execution.PartitionedFileUtil import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging { +trait FileScan extends Scan + with Batch with SupportsReportStatistics with SupportsMetadata with Logging { /** * Returns whether a file with `path` could be split or not. 
*/ @@ -93,23 +95,28 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin override def hashCode(): Int = getClass.hashCode() + val maxMetadataValueLength = 100 + override def description(): String = { - val maxMetadataValueLength = 100 + val metadataStr = getMetaData().toSeq.sorted.map { + case (key, value) => + val redactedValue = + Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, value) + key + ": " + StringUtils.abbreviate(redactedValue, maxMetadataValueLength) + }.mkString(", ") + s"${this.getClass.getSimpleName} $metadataStr" + } + + override def getMetaData(): Map[String, String] = { val locationDesc = fileIndex.getClass.getSimpleName + Utils.buildLocationMetadata(fileIndex.rootPaths, maxMetadataValueLength) - val metadata: Map[String, String] = Map( + Map( + "Format" -> s"${this.getClass.getSimpleName.replace("Scan", "").toLowerCase(Locale.ROOT)}", "ReadSchema" -> readDataSchema.catalogString, "PartitionFilters" -> seqToString(partitionFilters), "DataFilters" -> seqToString(dataFilters), "Location" -> locationDesc) - val metadataStr = metadata.toSeq.sorted.map { - case (key, value) => - val redactedValue = - Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, value) - key + ": " + StringUtils.abbreviate(redactedValue, maxMetadataValueLength) - }.mkString(", ") - s"${this.getClass.getSimpleName} $metadataStr" } protected def partitions: Seq[FilePartition] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 1a6f03f54f2e9..7f6ae20d5cd0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -63,7 +63,7 @@ object PushDownUtils extends PredicateHelper { val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter => DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) } - (r.pushedFilters(), untranslatableExprs ++ postScanFilters) + (r.pushedFilters(), (untranslatableExprs ++ postScanFilters).toSeq) case _ => (Nil, filters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala index 9188f4eb60d56..ceeed0f840700 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowNamespacesExec.scala @@ -52,6 +52,6 @@ case class ShowNamespacesExec( } } - rows + rows.toSeq } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala index 820f5ae8f1b12..5ba01deae9513 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala @@ -49,6 +49,6 @@ case class ShowTablesExec( } } - rows + rows.toSeq } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 2ed33b867183b..df3f231f7d0ef 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -288,7 +288,7 @@ private[sql] object V2SessionCatalog { s"SessionCatalog does not support partition transform: $transform") } - (identityCols, bucketSpec) + (identityCols.toSeq, bucketSpec) } private def toCatalogDatabase( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index 4f510322815ef..efb21e1c1e597 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -107,4 +107,8 @@ case class CSVScan( override def description(): String = { super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]") } + + override def getMetaData(): Map[String, String] = { + super.getMetaData() ++ Map("PushedFilers" -> seqToString(pushedFilters)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 62894fa7a2538..38b8ced51a141 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -65,6 +65,10 @@ case class OrcScan( super.description() + ", PushedFilters: " + seqToString(pushedFilters) } + override def getMetaData(): Map[String, String] = { + super.getMetaData() ++ Map("PushedFilers" -> seqToString(pushedFilters)) + } + override def withFilters( partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index bb315262a8211..c9c1e28a36960 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -93,6 +93,10 @@ case class ParquetScan( super.description() + ", PushedFilters: " + seqToString(pushedFilters) } + override def getMetaData(): Map[String, String] = { + super.getMetaData() ++ Map("PushedFilers" -> seqToString(pushedFilters)) + } + override def withFilters( partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 3242ac21ab324..186bac6f43332 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -162,7 +162,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { return (leftKeys, rightKeys) } } - (leftKeysBuffer, rightKeysBuffer) + (leftKeysBuffer.toSeq, rightKeysBuffer.toSeq) } private def reorderJoinKeys( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index fcbd0b19515b1..dadf1129c34b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -103,11 +103,11 @@ case class AggregateInPandasExec( // Schema of input rows to the python runner val aggInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => StructField(s"_$i", dt) - }) + }.toSeq) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { - val prunedProj = UnsafeProjection.create(allInputs, child.output) + val prunedProj = UnsafeProjection.create(allInputs.toSeq, child.output) val grouped = if (groupingExpressions.isEmpty) { // Use an empty unsafe row as a place holder for the grouping key diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 96e3bb721a822..298d63478b63e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -114,10 +114,10 @@ trait EvalPythonExec extends UnaryExecNode { } }.toArray }.toArray - val projection = MutableProjection.create(allInputs, child.output) + val projection = MutableProjection.create(allInputs.toSeq, child.output) val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => StructField(s"_$i", dt) - }) + }.toSeq) // Add rows to queue to join later with the result. 
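// Illustrative sketch (not part of this patch): with FileScan exposing getMetaData() and
// DataSourceV2ScanExecBase rendering it (redacted) in verboseStringWithOperatorId, the
// formatted explain output of a v2 file scan lists entries such as Format, ReadSchema,
// DataFilters, PartitionFilters, Location and the pushed filters. The snippet below assumes
// the file source is routed through the v2 path; emptying spark.sql.sources.useV1SourceList
// is one way to do that, and the output path is hypothetical.
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder()
  .master("local[1]").appName("scan-metadata")
  .config("spark.sql.sources.useV1SourceList", "")  // route file sources through DSv2
  .getOrCreate()

spark.range(10).selectExpr("id", "id % 2 AS p").write.mode("overwrite").parquet("/tmp/scan_meta_demo")
spark.read.parquet("/tmp/scan_meta_demo")
  .filter("id > 5")
  .explain("formatted")   // the scan node's detail section comes from getMetaData()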
val projectedRowIter = iter.map { inputRow => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 520afad287648..7fe3263630820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -198,7 +198,7 @@ object EvaluatePython { case udt: UserDefinedType[_] => makeFromJava(udt.sqlType) - case other => (obj: Any) => nullSafeConvert(other)(PartialFunction.empty) + case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty) } private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 7bc8b95cfb03b..1c88056cb50c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -72,7 +72,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { } } // There is no Python UDF over aggregate expression - Project(projList, agg.copy(aggregateExpressions = aggExpr)) + Project(projList.toSeq, agg.copy(aggregateExpressions = aggExpr.toSeq)) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -134,9 +134,9 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] { }.asInstanceOf[NamedExpression] } agg.copy( - groupingExpressions = groupingExpr, + groupingExpressions = groupingExpr.toSeq, aggregateExpressions = aggExpr, - child = Project(projList ++ agg.child.output, agg.child)) + child = Project((projList ++ agg.child.output).toSeq, agg.child)) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 68ce991a8ae7f..2da0000dad4ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -118,6 +118,6 @@ private[python] object PandasGroupUtils { // Attributes after deduplication val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes - (dedupAttributes, argOffsets) + (dedupAttributes.toSeq, argOffsets) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index e8ae0eaf0ea48..29537cc0e573f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -347,7 +347,7 @@ object CompactibleFileStreamLog { } else if (defaultInterval < (latestCompactBatchId + 1) / 2) { // Find the first divisor >= default compact interval def properDivisors(min: Int, n: Int) = - (min to n/2).view.filter(i => n % i == 0) :+ n + (min to n/2).view.filter(i => n % i == 0).toSeq :+ n properDivisors(defaultInterval, latestCompactBatchId + 1).head } else { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala index 985a5fa6063ef..11bdfee460e66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.streaming +import scala.collection.mutable + import org.apache.spark.SparkEnv import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -33,7 +35,7 @@ case class GetRecord(offset: ContinuousRecordPartitionOffset) * to the number of partitions. * @param lock a lock object for locking the buckets for read */ -class ContinuousRecordEndpoint(buckets: Seq[Seq[UnsafeRow]], lock: Object) +class ContinuousRecordEndpoint(buckets: Seq[mutable.Seq[UnsafeRow]], lock: Object) extends ThreadSafeRpcEndpoint { private var startOffsets: Seq[Int] = List.fill(buckets.size)(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index e8ce8e1487093..f2557696485b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -482,7 +482,7 @@ object FileStreamSource { } private def buildSourceGlobFilters(sourcePath: Path): Seq[GlobFilter] = { - val filters = new scala.collection.mutable.MutableList[GlobFilter]() + val filters = new scala.collection.mutable.ArrayBuffer[GlobFilter]() var currentPath = sourcePath while (!currentPath.isRoot) { @@ -490,7 +490,7 @@ object FileStreamSource { currentPath = currentPath.getParent } - filters.toList + filters.toSeq } override protected def cleanTask(entry: FileEntry): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala index f6cc8116c6c4c..de8a8cd7d3b58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala @@ -139,7 +139,7 @@ class ManifestFileCommitProtocol(jobId: String, path: String) if (addedFiles.nonEmpty) { val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration) val statuses: Seq[SinkFileStatus] = - addedFiles.map(f => SinkFileStatus(fs.getFileStatus(new Path(f)))) + addedFiles.map(f => SinkFileStatus(fs.getFileStatus(new Path(f)))).toSeq new TaskCommitMessage(statuses) } else { new TaskCommitMessage(Seq.empty[SinkFileStatus]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 2c737206dd2d9..fe3f0e95b383c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -127,8 +127,8 @@ trait ProgressReporter extends Logging { * `committedOffsets` in `StreamExecution` to make sure that the correct range is 
recorded. */ protected def recordTriggerOffsets(from: StreamProgress, to: StreamProgress): Unit = { - currentTriggerStartOffsets = from.mapValues(_.json) - currentTriggerEndOffsets = to.mapValues(_.json) + currentTriggerStartOffsets = from.mapValues(_.json).toMap + currentTriggerEndOffsets = to.mapValues(_.json).toMap } private def updateProgress(newProgress: StreamingQueryProgress): Unit = { @@ -192,7 +192,8 @@ trait ProgressReporter extends Logging { timestamp = formatTimestamp(currentTriggerStartTimestamp), batchId = currentBatchId, batchDuration = processingTimeMills, - durationMs = new java.util.HashMap(currentDurationsMs.toMap.mapValues(long2Long).asJava), + durationMs = + new java.util.HashMap(currentDurationsMs.toMap.mapValues(long2Long).toMap.asJava), eventTime = new java.util.HashMap(executionStats.eventTimeStats.asJava), stateOperators = executionStats.stateOperators.toArray, sources = sourceProgress.toArray, @@ -255,14 +256,14 @@ trait ProgressReporter extends Logging { "avg" -> stats.avg.toLong).mapValues(formatTimestamp) }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp - ExecutionStats(numInputRows, stateOperators, eventTimeStats) + ExecutionStats(numInputRows, stateOperators, eventTimeStats.toMap) } /** Extract number of input sources for each streaming source in plan */ private def extractSourceToNumInputRows(): Map[SparkDataStream, Long] = { def sumRows(tuples: Seq[(SparkDataStream, Long)]): Map[SparkDataStream, Long] = { - tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + tuples.groupBy(_._1).mapValues(_.map(_._2).sum).toMap // sum up rows for each source } val onlyDataSourceV2Sources = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index dc5fc2e43143d..3d071df493cec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -295,6 +295,10 @@ case class StreamingSymmetricHashJoinExec( postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) } } + + // NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of + // elements in `innerOutputIter`, because evaluation of `innerOutputIter` may update + // the match flag which the logic for outer join is relying on. val removedRowIter = leftSideJoiner.removeOldState() val outerOutputIter = removedRowIter.filterNot { kv => stateFormatVersion match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala deleted file mode 100644 index 4c621890c9793..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} - -/** - * Physical plan for coalescing a continuous processing plan. - * - * Currently, only coalesces to a single partition are supported. `numPartitions` must be 1. - */ -case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecNode { - override def output: Seq[Attribute] = child.output - - override def outputPartitioning: Partitioning = SinglePartition - - override def doExecute(): RDD[InternalRow] = { - assert(numPartitions == 1) - new ContinuousCoalesceRDD( - sparkContext, - numPartitions, - conf.continuousStreamingExecutorQueueSize, - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong, - child.execute()) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala deleted file mode 100644 index 14046f6a99c24..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous - -import java.util.UUID - -import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.continuous.shuffle._ -import org.apache.spark.util.ThreadUtils - -case class ContinuousCoalesceRDDPartition( - index: Int, - endpointName: String, - queueSize: Int, - numShuffleWriters: Int, - epochIntervalMs: Long) - extends Partition { - // Initialized only on the executor, and only once even as we call compute() multiple times. 
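The "initialized only on the executor" comment that closes this hunk describes the lazy-initialization pattern the removed partition classes rely on: executor-side state lives in a `lazy val` on the `Partition` object, so it is forced only inside `compute()` and at most once per deserialized partition. A stripped-down, hypothetical sketch of the pattern (the removed code below does the same thing with an RPC endpoint and a task-completion listener):

```scala
import org.apache.spark.Partition

// Hypothetical sketch: the lazy val is not evaluated on the driver when the partition
// is created or serialized; it is forced the first time compute() touches it on the
// executor, and repeated compute() calls reuse the same instance.
case class EndpointPartition(index: Int, endpointName: String) extends Partition {
  lazy val executorState: String = {
    // Executor-side setup (e.g. registering an RPC endpoint) would go here.
    s"initialized-$endpointName"
  }
}
```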
- lazy val (reader: ContinuousShuffleReader, endpoint) = { - val env = SparkEnv.get.rpcEnv - val receiver = new RPCContinuousShuffleReader( - queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(endpointName, receiver) - - TaskContext.get().addTaskCompletionListener[Unit] { ctx => - env.stop(endpoint) - } - (receiver, endpoint) - } - // This flag will be flipped on the executors to indicate that the threads processing - // partitions of the write-side RDD have been started. These will run indefinitely - // asynchronously as epochs of the coalesce RDD complete on the read side. - private[continuous] var writersInitialized: Boolean = false -} - -/** - * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local - * continuous shuffle, and then reads them in the task thread using `reader`. - */ -class ContinuousCoalesceRDD( - context: SparkContext, - numPartitions: Int, - readerQueueSize: Int, - epochIntervalMs: Long, - prev: RDD[InternalRow]) - extends RDD[InternalRow](context, Nil) { - - // When we support more than 1 target partition, we'll need to figure out how to pass in the - // required partitioner. - private val outputPartitioner = new HashPartitioner(1) - - private val readerEndpointNames = (0 until numPartitions).map { i => - s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}" - } - - override def getPartitions: Array[Partition] = { - (0 until numPartitions).map { partIndex => - ContinuousCoalesceRDDPartition( - partIndex, - readerEndpointNames(partIndex), - readerQueueSize, - prev.getNumPartitions, - epochIntervalMs) - }.toArray - } - - private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool( - prev.getNumPartitions, - this.name) - - override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { - val part = split.asInstanceOf[ContinuousCoalesceRDDPartition] - - if (!part.writersInitialized) { - val rpcEnv = SparkEnv.get.rpcEnv - - // trigger lazy initialization - part.endpoint - val endpointRefs = readerEndpointNames.map { endpointName => - rpcEnv.setupEndpointRef(rpcEnv.address, endpointName) - } - - val runnables = prev.partitions.map { prevSplit => - new Runnable() { - override def run(): Unit = { - TaskContext.setTaskContext(context) - - val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter( - prevSplit.index, outputPartitioner, endpointRefs.toArray) - - EpochTracker.initializeCurrentEpoch( - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) - while (!context.isInterrupted() && !context.isCompleted()) { - writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]]) - // Note that current epoch is a inheritable thread local but makes another instance, - // so each writer thread can properly increment its own epoch without affecting - // the main task thread. 
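The note just above about the current epoch being an inheritable thread local refers to the `InheritableThreadLocal` pattern in which each child thread receives its own copy of the value, so the `EpochTracker.incrementCurrentEpoch()` call that follows only advances the writer thread's counter. A minimal sketch of that mechanism with illustrative names, not the actual `EpochTracker` internals:

```scala
import java.util.concurrent.atomic.AtomicLong

// Minimal sketch: childValue hands each newly started thread its own AtomicLong seeded
// from the parent's value, so increments in a writer thread never mutate the counter
// seen by the thread that spawned it.
object EpochCounterSketch {
  private val current = new InheritableThreadLocal[AtomicLong] {
    override protected def initialValue(): AtomicLong = new AtomicLong(0L)
    override protected def childValue(parent: AtomicLong): AtomicLong =
      new AtomicLong(parent.get())
  }

  def get: Long = current.get().get()
  def increment(): Unit = current.get().incrementAndGet()
}
```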
- EpochTracker.incrementCurrentEpoch() - } - } - } - } - - context.addTaskCompletionListener[Unit] { ctx => - threadPool.shutdownNow() - } - - part.writersInitialized = true - - runnables.foreach(threadPool.execute) - } - - part.reader.read() - } - - override def clearDependencies(): Unit = { - throw new IllegalStateException("Continuous RDDs cannot be checkpointed") - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index a109c2171f3d2..d225e65aabe11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -206,9 +206,6 @@ class ContinuousExecution( currentEpochCoordinatorId = epochCoordinatorId sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) - sparkSessionForQuery.sparkContext.setLocalProperty( - ContinuousExecution.EPOCH_INTERVAL_KEY, - trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( @@ -436,5 +433,4 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" - val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala deleted file mode 100644 index 9b13f6398d837..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous.shuffle - -import java.util.UUID - -import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.RpcAddress -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.NextIterator - -case class ContinuousShuffleReadPartition( - index: Int, - endpointName: String, - queueSize: Int, - numShuffleWriters: Int, - epochIntervalMs: Long) - extends Partition { - // Initialized only on the executor, and only once even as we call compute() multiple times. 
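Returning to the `ContinuousExecution` hunk above: the removed `EPOCH_INTERVAL_KEY` was one use of Spark's local-property plumbing, where a value set on the driver with `SparkContext.setLocalProperty` is readable inside tasks via `TaskContext.getLocalProperty`. An illustrative sketch of that mechanism (the property key below is hypothetical):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

object LocalPropertySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("local-property-sketch").getOrCreate()
    val sc = spark.sparkContext

    // A driver-side local property is propagated to the tasks of jobs submitted
    // from this thread and can be read back through the TaskContext.
    sc.setLocalProperty("example.epoch.interval.ms", "1000") // hypothetical key
    sc.parallelize(1 to 4, 2).foreachPartition { _ =>
      val interval = TaskContext.get().getLocalProperty("example.epoch.interval.ms")
      println(s"task sees interval = $interval")
    }

    spark.stop()
  }
}
```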
- lazy val (reader: ContinuousShuffleReader, endpoint) = { - val env = SparkEnv.get.rpcEnv - val receiver = new RPCContinuousShuffleReader( - queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(endpointName, receiver) - - TaskContext.get().addTaskCompletionListener[Unit] { ctx => - env.stop(endpoint) - } - (receiver, endpoint) - } -} - -/** - * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their - * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks - * poll from their receiver until an epoch marker is sent. - * - * @param sc the RDD context - * @param numPartitions the number of read partitions for this RDD - * @param queueSize the size of the row buffers to use - * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD - * @param epochIntervalMs the checkpoint interval of the streaming query - */ -class ContinuousShuffleReadRDD( - sc: SparkContext, - numPartitions: Int, - queueSize: Int = 1024, - numShuffleWriters: Int = 1, - epochIntervalMs: Long = 1000, - val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}")) - extends RDD[UnsafeRow](sc, Nil) { - - override protected def getPartitions: Array[Partition] = { - (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition( - partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs) - }.toArray - } - - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala deleted file mode 100644 index 502ae0d4822e8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous.shuffle - -import java.util.concurrent._ -import java.util.concurrent.atomic.AtomicBoolean - -import org.apache.spark.internal.Logging -import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.util.NextIterator - -/** - * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker. - * - * Each message comes tagged with writerId, identifying which writer the message is coming - * from. 
The receiver will only begin the next epoch once all writers have sent an epoch - * marker ending the current epoch. - */ -private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable { - def writerId: Int -} -private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) - extends RPCContinuousShuffleMessage -private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage - -/** - * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle - * writers will send rows here, with continuous shuffle readers polling for new rows as needed. - * - * TODO: Support multiple source tasks. We need to output a single epoch marker once all - * source tasks have sent one. - */ -private[continuous] class RPCContinuousShuffleReader( - queueSize: Int, - numShuffleWriters: Int, - epochIntervalMs: Long, - override val rpcEnv: RpcEnv) - extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { - // Note that this queue will be drained from the main task thread and populated in the RPC - // response thread. - private val queues = Array.fill(numShuffleWriters) { - new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize) - } - - // Exposed for testing to determine if the endpoint gets stopped on task end. - private[shuffle] val stopped = new AtomicBoolean(false) - - override def onStop(): Unit = { - stopped.set(true) - } - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case r: RPCContinuousShuffleMessage => - // Note that this will block a thread the shared RPC handler pool! - // The TCP based shuffle handler (SPARK-24541) will avoid this problem. - queues(r.writerId).put(r) - context.reply(()) - } - - override def read(): Iterator[UnsafeRow] = { - new NextIterator[UnsafeRow] { - // An array of flags for whether each writer ID has gotten an epoch marker. - private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) - - private val executor = Executors.newFixedThreadPool(numShuffleWriters) - private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor) - - private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] { - override def call(): RPCContinuousShuffleMessage = queues(writerId).take() - } - - // Initialize by submitting tasks to read the first row from each writer. - (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId))) - - /** - * In each call to getNext(), we pull the next row available in the completion queue, and then - * submit another task to read the next row from the writer which returned it. - * - * When a writer sends an epoch marker, we note that it's finished and don't submit another - * task for it in this epoch. The iterator is over once all writers have sent an epoch marker. - */ - override def getNext(): UnsafeRow = { - var nextRow: UnsafeRow = null - while (!finished && nextRow == null) { - completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { - case null => - // Try again if the poll didn't wait long enough to get a real result. - // But we should be getting at least an epoch marker every checkpoint interval. - val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect { - case (flag, idx) if !flag => idx - } - logWarning( - s"Completion service failed to make progress after $epochIntervalMs ms. 
Waiting " + - s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.") - - // The completion service guarantees this future will be available immediately. - case future => future.get() match { - case ReceiverRow(writerId, r) => - // Start reading the next element in the queue we just took from. - completion.submit(completionTask(writerId)) - nextRow = r - case ReceiverEpochMarker(writerId) => - // Don't read any more from this queue. If all the writers have sent epoch markers, - // the epoch is over; otherwise we need to loop again to poll from the remaining - // writers. - writerEpochMarkersReceived(writerId) = true - if (writerEpochMarkersReceived.forall(_ == true)) { - finished = true - } - } - } - } - - nextRow - } - - override def close(): Unit = { - executor.shutdownNow() - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala deleted file mode 100644 index 1c6f3ddb395e6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous.shuffle - -import scala.concurrent.Future -import scala.concurrent.duration.Duration - -import org.apache.spark.Partitioner -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.util.ThreadUtils - -/** - * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. - * - * @param writerId The partition ID of this writer. - * @param outputPartitioner The partitioner on the reader side of the shuffle. - * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by - * partition ID within outputPartitioner. 
- */ -class RPCContinuousShuffleWriter( - writerId: Int, - outputPartitioner: Partitioner, - endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter { - - if (outputPartitioner.numPartitions != 1) { - throw new IllegalArgumentException("multiple readers not yet supported") - } - - if (outputPartitioner.numPartitions != endpoints.length) { - throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + - s"not match endpoint count ${endpoints.length}") - } - - def write(epoch: Iterator[UnsafeRow]): Unit = { - while (epoch.hasNext) { - val row = epoch.next() - endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) - } - - val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq - implicit val ec = ThreadUtils.sameThread - ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index e5b9e68d71026..9adb9af7318d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -215,7 +215,7 @@ case class MemoryStream[A : Encoder]( batches.slice(sliceStart, sliceEnd) } - logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) + logDebug(generateDebugString(newBlocks.flatten.toSeq, startOrdinal, endOrdinal)) numPartitions match { case Some(numParts) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala index 03ebbb9f1b376..24ff9c2e8384d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala @@ -80,7 +80,7 @@ class MemorySink extends Table with SupportsWrite with Logging { /** Returns all rows that are stored in this [[Sink]]. 
*/ def allData: Seq[Row] = synchronized { - batches.flatMap(_.data) + batches.flatMap(_.data).toSeq } def latestBatchId: Option[Long] = synchronized { @@ -92,7 +92,7 @@ class MemorySink extends Table with SupportsWrite with Logging { } def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { - batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + batches.filter(_.batchId > sinceBatchId).flatMap(_.data).toSeq } def toDebugString: String = synchronized { @@ -183,7 +183,7 @@ class MemoryDataWriter(partition: Int, schema: StructType) } override def commit(): MemoryWriterCommitMessage = { - val msg = MemoryWriterCommitMessage(partition, data.clone()) + val msg = MemoryWriterCommitMessage(partition, data.clone().toSeq) data.clear() msg } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 0eb3dce1bbd27..90a53727aa317 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -74,20 +74,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( StateStoreId(checkpointLocation, operatorId, partition.index), queryRunId) - // If we're in continuous processing mode, we should get the store version for the current - // epoch rather than the one at planning time. - val isContinuous = Option(ctxt.getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) - .map(_.toBoolean).getOrElse(false) - val currentVersion = if (isContinuous) { - val epoch = EpochTracker.getCurrentEpoch - assert(epoch.isDefined, "Current epoch must be defined for continuous processing streams.") - epoch.get - } else { - storeVersion - } - store = StateStore.get( - storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 1a0a43c083879..1a5b50dcc7901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -451,10 +451,25 @@ class SymmetricHashJoinStateManager( } private trait KeyWithIndexToValueRowConverter { + /** Defines the schema of the value row (the value side of K-V in state store). */ def valueAttributes: Seq[Attribute] + /** + * Convert the value row to (actual value, match) pair. + * + * NOTE: implementations should ensure the result row is NOT reused during execution, so + * that caller can safely read the value in any time. + */ def convertValue(value: UnsafeRow): ValueAndMatchPair + /** + * Build the value row from (actual value, match) pair. This is expected to be called just + * before storing to the state store. + * + * NOTE: depending on the implementation, the result row "may" be reused during execution + * (to avoid initialization of object), so the caller should ensure that the logic doesn't + * affect by such behavior. Call copy() against the result row if needed. 
+ */ def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow } @@ -493,7 +508,7 @@ class SymmetricHashJoinStateManager( override def convertValue(value: UnsafeRow): ValueAndMatchPair = { if (value != null) { - ValueAndMatchPair(valueRowGenerator(value), + ValueAndMatchPair(valueRowGenerator(value).copy(), value.getBoolean(indexOrdinalInValueWithMatchedRow)) } else { null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index a9c01e69b9b13..497b13793a67b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -97,7 +97,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => .map(entry => entry._1 -> longMetric(entry._1).value) val javaConvertedCustomMetrics: java.util.HashMap[String, java.lang.Long] = - new java.util.HashMap(customMetrics.mapValues(long2Long).asJava) + new java.util.HashMap(customMetrics.mapValues(long2Long).toMap.asJava) new StateOperatorProgress( numRowsTotal = longMetric("numTotalStateRows").value, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 33539c01ee5dd..ff229c2bea7ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -57,7 +57,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L if (running.nonEmpty) { val runningPageTable = - executionsTable(request, "running", running, currentTime, true, true, true) + executionsTable(request, "running", running.toSeq, currentTime, true, true, true) _content ++= +-- !query output +1.0 1.2 0.1 0.1 + + +-- !query +select -1F, -1.2F, -.10F, -0.10F +-- !query schema +struct<-1.0:float,-1.2:float,-0.1:float,-0.1:float> +-- !query output +-1.0 -1.2 -0.1 -0.1 + + +-- !query +select -3.4028235E39f +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +Numeric literal -3.4028235E39 does not fit in range [-3.4028234663852886E+38, 3.4028234663852886E+38] for type float(line 1, pos 7) + +== SQL == +select -3.4028235E39f +-------^^^ + + -- !query select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1 -- !query schema @@ -216,6 +246,14 @@ struct<0.3:decimal(1,1),-0.8:decimal(1,1),0.5:decimal(1,1),-0.18:decimal(2,2),0. 
0.3 -0.8 0.5 -0.18 0.1111 0.1111 +-- !query +select 0.3 F, 0.4 D, 0.5 BD +-- !query schema +struct +-- !query output +0.3 0.4 0.5 + + -- !query select 123456789012345678901234567890123456789e10d, 123456789012345678901234567890123456789.1e10d -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index f6720f6c5faa4..ea74bb7175e96 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 50 +-- Number of queries: 54 -- !query @@ -164,6 +164,36 @@ decimal can only support precision up to 38 select 1234567890123456789012345678901234567890.0 +-- !query +select 1F, 1.2F, .10f, 0.10f +-- !query schema +struct<1.0:float,1.2:float,0.1:float,0.1:float> +-- !query output +1.0 1.2 0.1 0.1 + + +-- !query +select -1F, -1.2F, -.10F, -0.10F +-- !query schema +struct<-1.0:float,-1.2:float,-0.1:float,-0.1:float> +-- !query output +-1.0 -1.2 -0.1 -0.1 + + +-- !query +select -3.4028235E39f +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +Numeric literal -3.4028235E39 does not fit in range [-3.4028234663852886E+38, 3.4028234663852886E+38] for type float(line 1, pos 7) + +== SQL == +select -3.4028235E39f +-------^^^ + + -- !query select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1 -- !query schema @@ -216,6 +246,14 @@ struct<0.3:decimal(1,1),-0.8:decimal(1,1),0.5:decimal(1,1),-0.18:decimal(2,2),0. 0.3 -0.8 0.5 -0.18 0.1111 0.1111 +-- !query +select 0.3 F, 0.4 D, 0.5 BD +-- !query schema +struct +-- !query output +0.3 0.4 0.5 + + -- !query select 123456789012345678901234567890123456789e10d, 123456789012345678901234567890123456789.1e10d -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fa06484a73d95..131ab1b94f59e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -923,4 +923,503 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString)) assert(inSet.sql === "('a' IN ('a', 'b'))") } + + def checkAnswerAndSchema( + df: => DataFrame, + expectedAnswer: Seq[Row], + expectedSchema: StructType): Unit = { + + checkAnswer(df, expectedAnswer) + assert(df.schema == expectedSchema) + } + + private lazy val structType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false))) + + private lazy val structLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil), + StructType(Seq(StructField("a", structType, nullable = false)))) + + private lazy val nullStructLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Nil), + StructType(Seq(StructField("a", structType, nullable = true)))) + + private lazy val structLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = false))), + nullable = false)))) + + private lazy val 
nullStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(null)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = true))), + nullable = false)))) + + private lazy val structLevel3: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = false))), + nullable = false))), + nullable = false)))) + + test("withField should throw an exception if called on a non-StructType column") { + intercept[AnalysisException] { + testData.withColumn("key", $"key".withField("a", lit(2))) + }.getMessage should include("struct argument should be struct type, got: int") + } + + test("withField should throw an exception if either fieldName or col argument are null") { + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".withField(null, lit(2))) + }.getMessage should include("fieldName cannot be null") + + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".withField("b", null)) + }.getMessage should include("col cannot be null") + + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".withField(null, null)) + }.getMessage should include("fieldName cannot be null") + } + + test("withField should throw an exception if any intermediate structs don't exist") { + intercept[AnalysisException] { + structLevel2.withColumn("a", 'a.withField("x.b", lit(2))) + }.getMessage should include("No such struct field x in a") + + intercept[AnalysisException] { + structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2))) + }.getMessage should include("No such struct field x in a") + } + + test("withField should throw an exception if intermediate field is not a struct") { + intercept[AnalysisException] { + structLevel1.withColumn("a", 'a.withField("b.a", lit(2))) + }.getMessage should include("struct argument should be struct type, got: int") + } + + test("withField should throw an exception if intermediate field reference is ambiguous") { + intercept[AnalysisException] { + val structLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = false), + StructField("a", structType, nullable = false))), + nullable = false)))) + + structLevel2.withColumn("a", 'a.withField("a.b", lit(2))) + }.getMessage should include("Ambiguous reference to fields") + } + + test("withField should add field with no name") { + checkAnswerAndSchema( + structLevel1.withColumn("a", $"a".withField("", lit(4))), + Row(Row(1, null, 3, 4)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should add field to struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("d", lit(4))), + Row(Row(1, null, 3, 4)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should add 
field to null struct") { + checkAnswerAndSchema( + nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))), + Row(null) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = true)))) + } + + test("withField should add field to nested null struct") { + checkAnswerAndSchema( + nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))), + Row(Row(null)) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = true))), + nullable = false)))) + } + + test("withField should add null field to struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))), + Row(Row(1, null, 3, null)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = true))), + nullable = false)))) + } + + test("withField should add multiple fields to struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), + Row(Row(1, null, 3, 4, 5)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false), + StructField("e", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should add field to nested struct") { + Seq( + structLevel2.withColumn("a", 'a.withField("a.d", lit(4))), + structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4)))) + ).foreach { df => + checkAnswerAndSchema( + df, + Row(Row(Row(1, null, 3, 4))) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should add field to deeply nested struct") { + checkAnswerAndSchema( + structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))), + Row(Row(Row(Row(1, null, 3, 4)))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("withField should replace field in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("b", lit(2))), + Row(Row(1, 2, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, 
nullable = false))), + nullable = false)))) + } + + test("withField should replace field in null struct") { + checkAnswerAndSchema( + nullStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), + Row(null) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = true)))) + } + + test("withField should replace field in nested null struct") { + checkAnswerAndSchema( + nullStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))), + Row(Row(null)) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = true))), + nullable = false)))) + } + + test("withField should replace field with null value in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))), + Row(Row(1, null, null)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true))), + nullable = false)))) + } + + test("withField should replace multiple fields in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))), + Row(Row(10, 20, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should replace field in nested struct") { + Seq( + structLevel2.withColumn("a", $"a".withField("a.b", lit(2))), + structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2)))) + ).foreach { df => + checkAnswerAndSchema( + df, + Row(Row(Row(1, 2, 3))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should replace field in deeply nested struct") { + checkAnswerAndSchema( + structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))), + Row(Row(Row(Row(1, 2, 3)))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("withField should replace all fields with given name in struct") { + val structLevel1 = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("b", lit(100))), + Row(Row(1, 100, 100)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", 
IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should replace fields in struct in given order") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))), + Row(Row(1, 20, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should add field and then replace same field in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))), + Row(Row(1, null, 3, 5)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false))), + nullable = false)))) + } + + test("withField should handle fields with dots in their name if correctly quoted") { + val df: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a.b", StructType(Seq( + StructField("c.d", IntegerType, nullable = false), + StructField("e.f", IntegerType, nullable = true), + StructField("g.h", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))), + Row(Row(Row(1, 2, 3))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a.b", StructType(Seq( + StructField("c.d", IntegerType, nullable = false), + StructField("e.f", IntegerType, nullable = false), + StructField("g.h", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + intercept[AnalysisException] { + df.withColumn("a", 'a.withField("a.b.e.f", lit(2))) + }.getMessage should include("No such struct field a in a.b") + } + + private lazy val mixedCaseStructLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 1)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + test("withField should replace field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + Row(Row(2, 1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("A", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + Row(Row(1, 2)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + } + } + + test("withField should add field to struct because casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + Row(Row(1, 1, 2)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", 
IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false), + StructField("A", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + Row(Row(1, 1, 2)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + } + } + + private lazy val mixedCaseStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, 1), Row(1, 1))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("B", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + test("withField should replace nested field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))), + Row(Row(Row(2, 1), Row(1, 1))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("A", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("B", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))), + Row(Row(Row(1, 1), Row(2, 1))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("b", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should throw an exception because casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + intercept[AnalysisException] { + mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))) + }.getMessage should include("No such struct field A in a, B") + + intercept[AnalysisException] { + mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))) + }.getMessage should include("No such struct field b in a, B") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index bd3f48078374d..e72b8ce860b28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -297,13 +297,12 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { // When generating expected results at here, we need to follow the implementation of // Rand expression. 
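As exercised by the `ColumnExpressionSuite` additions above, the new `Column.withField` API adds or replaces a named field inside a (possibly nested) struct column. A brief usage sketch mirroring those tests; the DataFrame contents and names here are illustrative:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lit, struct}

object WithFieldSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]").appName("withField-sketch").getOrCreate()
    import spark.implicits._

    // A single-row DataFrame with one struct column s = {a: 1, b: 2}.
    val df = Seq((1, 2)).toDF("a", "b").select(struct($"a", $"b").as("s"))

    // Add a new field "c", then replace the existing field "b" inside struct "s".
    val updated = df
      .withColumn("s", $"s".withField("c", lit(3)))
      .withColumn("s", $"s".withField("b", lit(20)))

    updated.show(truncate = false) // s now holds a = 1, b = 20, c = 3
    spark.stop()
  }
}
```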
- def expected(df: DataFrame): Seq[Row] = { + def expected(df: DataFrame): Seq[Row] = df.rdd.collectPartitions().zipWithIndex.flatMap { case (data, index) => val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) data.filter(_.getInt(0) < rng.nextDouble() * 10) - } - } + }.toSeq val union = df1.union(df2) checkAnswer( @@ -506,4 +505,35 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) check(lit(2).cast("int"), $"c" =!= 2, Seq()) } + + test("SPARK-29358: Make unionByName optionally fill missing columns with nulls") { + var df1 = Seq(1, 2, 3).toDF("a") + var df2 = Seq(3, 1, 2).toDF("b") + val df3 = Seq(2, 3, 1).toDF("c") + val unionDf = df1.unionByName(df2.unionByName(df3, true), true) + checkAnswer(unionDf, + Row(1, null, null) :: Row(2, null, null) :: Row(3, null, null) :: // df1 + Row(null, 3, null) :: Row(null, 1, null) :: Row(null, 2, null) :: // df2 + Row(null, null, 2) :: Row(null, null, 3) :: Row(null, null, 1) :: Nil // df3 + ) + + df1 = Seq((1, 2)).toDF("a", "c") + df2 = Seq((3, 4, 5)).toDF("a", "b", "c") + checkAnswer(df1.unionByName(df2, true), + Row(1, 2, null) :: Row(3, 5, 4) :: Nil) + checkAnswer(df2.unionByName(df1, true), + Row(3, 4, 5) :: Row(1, null, 2) :: Nil) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + df2 = Seq((3, 4, 5)).toDF("a", "B", "C") + val union1 = df1.unionByName(df2, true) + val union2 = df2.unionByName(df1, true) + + checkAnswer(union1, Row(1, 2, null, null) :: Row(3, null, 4, 5) :: Nil) + checkAnswer(union2, Row(3, 4, 5, null) :: Row(1, null, null, 2) :: Nil) + + assert(union1.schema.fieldNames === Array("a", "c", "B", "C")) + assert(union2.schema.fieldNames === Array("a", "B", "C", "c")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8359dff674a87..52ef5895ed9ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -195,22 +195,14 @@ class DataFrameSuite extends QueryTest private def assertDecimalSumOverflow( df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { if (!ansiEnabled) { - try { - checkAnswer(df, expectedAnswer) - } catch { - case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] => - // This is an existing bug that we can write overflowed decimal to UnsafeRow but fail - // to read it. 
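The SPARK-29358 test above introduces the `unionByName(other, allowMissingColumns)` overload; a condensed usage sketch of the same behaviour, reusing the values from that test:

```scala
import org.apache.spark.sql.SparkSession

object UnionByNameSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]").appName("unionByName-sketch").getOrCreate()
    import spark.implicits._

    val df1 = Seq((1, 2)).toDF("a", "c")
    val df2 = Seq((3, 4, 5)).toDF("a", "b", "c")

    // With allowMissingColumns = true, columns missing on either side are filled
    // with nulls instead of raising an AnalysisException.
    // Result columns: a, c, b; rows: (1, 2, null) and (3, 5, 4).
    df1.unionByName(df2, true).show()

    spark.stop()
  }
}
```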
- assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) - } + checkAnswer(df, expectedAnswer) } else { val e = intercept[SparkException] { - df.collect + df.collect() } assert(e.getCause.isInstanceOf[ArithmeticException]) assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || - e.getCause.getMessage.contains("Overflow in sum of decimals") || - e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + e.getCause.getMessage.contains("Overflow in sum of decimals")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index ac2ebd8bd748b..508eefafd0754 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -336,7 +336,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(years($"ts")) .create() @@ -350,7 +349,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(months($"ts")) .create() @@ -364,7 +362,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(days($"ts")) .create() @@ -378,7 +375,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo spark.table("source") .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(hours($"ts")) .create() @@ -391,7 +387,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo test("Create: partitioned by bucket(4, id)") { spark.table("source") .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(bucket(4, $"id")) .create() @@ -596,7 +591,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", lit("America/Los_Angeles") as "timezone")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy( years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"), years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified") @@ -624,7 +618,6 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", lit("America/Los_Angeles") as "timezone")) .writeTo("testcat.table_name") - .tableProperty("allow-unsupported-transforms", "true") .partitionedBy(bucket(4, $"ts.timezone")) .create() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 124b58483d24f..2be86b9ad6208 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -223,16 +223,6 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSparkSession { checkDataset(Seq(Queue(true)).toDS(), Queue(true)) checkDataset(Seq(Queue("test")).toDS(), Queue("test")) checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1))) - - checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1)) - checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong)) - checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble)) - checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat)) - checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte)) - checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort)) - checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true)) - checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test")) - checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1))) } test("sequence and product combinations") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 1ad97185a564a..70303792fdf1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, E import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} trait ExplainSuiteHelper extends QueryTest with SharedSparkSession { @@ -360,6 +360,54 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite } } } + + test("Explain formatted output for scan operator for datasource V2") { + withTempDir { dir => + Seq("parquet", "orc", "csv", "json").foreach { fmt => + val basePath = dir.getCanonicalPath + "/" + fmt + val pushFilterMaps = Map ( + "parquet" -> + "|PushedFilers: \\[.*\\(id\\), .*\\(value\\), .*\\(id,1\\), .*\\(value,2\\)\\]", + "orc" -> + "|PushedFilers: \\[.*\\(id\\), .*\\(value\\), .*\\(id,1\\), .*\\(value,2\\)\\]", + "csv" -> + "|PushedFilers: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", + "json" -> + "|remove_marker" + ) + val expected_plan_fragment1 = + s""" + |\\(1\\) BatchScan + |Output \\[2\\]: \\[value#x, id#x\\] + |DataFilters: \\[isnotnull\\(value#x\\), \\(value#x > 2\\)\\] + |Format: $fmt + |Location: InMemoryFileIndex\\[.*\\] + |PartitionFilters: \\[isnotnull\\(id#x\\), \\(id#x > 1\\)\\] + ${pushFilterMaps.get(fmt).get} + |ReadSchema: struct\\ + |""".stripMargin.replaceAll("\nremove_marker", "").trim + + spark.range(10) + .select(col("id"), col("id").as("value")) + .write.option("header", true) + .partitionBy("id") + .format(fmt) + .save(basePath) + val readSchema = + StructType(Seq(StructField("id", IntegerType), StructField("value", IntegerType))) + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + val df = spark + .read + .schema(readSchema) + .option("header", true) + .format(fmt) + .load(basePath).where($"id" > 1 && $"value" > 2) + val normalizedOutput = getNormalizedExplain(df, FormattedMode) + assert(expected_plan_fragment1.r.findAllMatchIn(normalizedOutput).length == 1) + } + } + } + } } class ExplainSuiteAE extends 
ExplainSuiteHelper with EnableAdaptiveExecutionSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index 234978b9ce176..9f4c24b46a9b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -570,4 +570,31 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP assert(joinHints == expectedHints) } } + + test("SPARK-32220: Non Cartesian Product Join Result Correct with SHUFFLE_REPLICATE_NL hint") { + withTempView("t1", "t2") { + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") + val df1 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t1.key = t2.key") + val df2 = sql("SELECT * from t1 join t2 ON t1.key = t2.key") + assert(df1.collect().size == df2.collect().size) + + val df3 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2") + val df4 = sql("SELECT * from t1 join t2") + assert(df3.collect().size == df4.collect().size) + + val df5 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t1.key < t2.key") + val df6 = sql("SELECT * from t1 join t2 ON t1.key < t2.key") + assert(df5.collect().size == df6.collect().size) + + val df7 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t1.key < 2") + val df8 = sql("SELECT * from t1 join t2 ON t1.key < 2") + assert(df7.collect().size == df8.collect().size) + + + val df9 = sql("SELECT /*+ shuffle_replicate_nl(t1) */ * from t1 join t2 ON t2.key < 2") + val df10 = sql("SELECT * from t1 join t2 ON t2.key < 2") + assert(df9.collect().size == df10.collect().size) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index fe6775cc7f9b9..f24da6df67ca0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -712,7 +712,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |ON | big.key = small.a """.stripMargin), - expected + expected.toSeq ) } @@ -729,7 +729,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |ON | big.key = small.a """.stripMargin), - expected + expected.toSeq ) } } @@ -770,7 +770,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |ON | big.key = small.a """.stripMargin), - expected + expected.toSeq ) } @@ -787,7 +787,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |ON | big.key = small.a """.stripMargin), - expected + expected.toSeq ) } @@ -806,7 +806,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |ON | big.key = small.a """.stripMargin), - expected + expected.toSeq ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index e52d2262a6bf8..8469216901b05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -418,7 +418,7 @@ object QueryTest extends Assertions { } def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = { - getErrorMessageInCheckAnswer(df, expectedAnswer.asScala) match { + 
getErrorMessageInCheckAnswer(df, expectedAnswer.asScala.toSeq) match { case Some(errorMessage) => Assert.fail(errorMessage) case None => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a219b91627b2b..989f304b1f07f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3521,6 +3521,45 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |""".stripMargin), Row(1)) } } + + test("SPARK-31875: remove hints from plan when spark.sql.optimizer.disableHints = true") { + withSQLConf(SQLConf.DISABLE_HINTS.key -> "true") { + withTempView("t1", "t2") { + Seq[Integer](1, 2).toDF("c1").createOrReplaceTempView("t1") + Seq[Integer](1, 2).toDF("c1").createOrReplaceTempView("t2") + val repartitionHints = Seq( + "COALESCE(2)", + "REPARTITION(c1)", + "REPARTITION(c1, 2)", + "REPARTITION_BY_RANGE(c1, 2)", + "REPARTITION_BY_RANGE(c1)" + ) + val joinHints = Seq( + "BROADCASTJOIN (t1)", + "MAPJOIN(t1)", + "SHUFFLE_MERGE(t1)", + "MERGEJOIN(t1)", + "SHUFFLE_REPLICATE_NL(t1)" + ) + + repartitionHints.foreach { hintName => + val sqlText = s"SELECT /*+ $hintName */ * FROM t1" + val sqlTextWithoutHint = "SELECT * FROM t1" + val expectedPlan = sql(sqlTextWithoutHint) + val actualPlan = sql(sqlText) + comparePlans(actualPlan.queryExecution.analyzed, expectedPlan.queryExecution.analyzed) + } + + joinHints.foreach { hintName => + val sqlText = s"SELECT /*+ $hintName */ * FROM t1 INNER JOIN t2 ON t1.c1 = t2.c1" + val sqlTextWithoutHint = "SELECT * FROM t1 INNER JOIN t2 ON t1.c1 = t2.c1" + val expectedPlan = sql(sqlTextWithoutHint) + val actualPlan = sql(sqlText) + comparePlans(actualPlan.queryExecution.analyzed, expectedPlan.queryExecution.analyzed) + } + } + } + } } case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index f0522dfeafaac..33247455b5cdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -153,8 +153,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper .set(SQLConf.SHUFFLE_PARTITIONS, 4) /** List of test cases to ignore, in lower cases. */ - protected def blackList: Set[String] = Set( - "blacklist.sql" // Do NOT remove this one. It is here to test the blacklist functionality. + protected def ignoreList: Set[String] = Set( + "ignored.sql" // Do NOT remove this one. It is here to test the ignore functionality. ) // Create all the test cases. @@ -222,7 +222,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper name: String, inputFile: String, resultFile: String) extends TestCase with AnsiTest protected def createScalaTestCase(testCase: TestCase): Unit = { - if (blackList.exists(t => + if (ignoreList.exists(t => testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) { // Create a test case to ignore this case. 
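// A minimal sketch of the SPARK-31875 behavior tested above. Assumptions (illustration only,
// not part of the patch): a SparkSession `spark`, temp views t1 and t2 as created in the test,
// and the config key being the one backing SQLConf.DISABLE_HINTS.
spark.conf.set("spark.sql.optimizer.disableHints", "true")
// With the flag enabled, the analyzer drops the hint, so the analyzed plan is identical to that
// of the equivalent un-hinted query -- which is exactly what comparePlans verifies in the test.
val hinted = spark.sql("SELECT /*+ BROADCASTJOIN(t1) */ * FROM t1 INNER JOIN t2 ON t1.c1 = t2.c1")
println(hinted.queryExecution.analyzed)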
ignore(testCase.name) { /* Do nothing */ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala index 4e85f739b95a2..1106a787cc9a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala @@ -238,7 +238,7 @@ abstract class ShowCreateTableSuite extends QueryTest with SQLTestUtils { table.copy( createTime = 0L, lastAccessTime = 0L, - properties = table.properties.filterKeys(!nondeterministicProps.contains(_)), + properties = table.properties.filterKeys(!nondeterministicProps.contains(_)).toMap, stats = None, ignoredProperties = Map.empty ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 347bc735a8b76..2bb9aa55e4579 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -992,7 +992,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s) s } - subqueryExpressions + subqueryExpressions.toSeq } private def getNumSorts(plan: LogicalPlan): Int = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index aacb625d7921f..d0d484ec434ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -85,7 +85,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSSchema { // List up the known queries having too large code in a generated function. 
// A JIRA file for `modified-q3` is as follows; // [SPARK-29128] Split predicate code in OR expressions - val blackListForMethodCodeSizeCheck = Set("modified-q3") + val excludeListForMethodCodeSizeCheck = Set("modified-q3") modifiedTPCDSQueries.foreach { name => val queryString = resourceToString(s"tpcds-modifiedQueries/$name.sql", @@ -94,7 +94,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSSchema { test(testName) { // check the plans can be properly generated val plan = sql(queryString).queryExecution.executedPlan - checkGeneratedCode(plan, !blackListForMethodCodeSizeCheck.contains(testName)) + checkGeneratedCode(plan, !excludeListForMethodCodeSizeCheck.contains(testName)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 5c1fe265c15d0..f0d5a61ad8006 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -18,13 +18,20 @@ package org.apache.spark.sql import java.math.BigDecimal +import java.sql.Timestamp +import java.time.{Instant, LocalDate} +import java.time.format.DateTimeFormatter +import org.apache.spark.SparkException import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{QueryExecution, SimpleMode} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand} import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.{lit, struct, udf} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -33,6 +40,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.QueryExecutionListener private case class FunctionResult(f1: String, f2: String) +private case class LocalDateInstantType(date: LocalDate, instant: Instant) +private case class TimestampInstantType(t: Timestamp, instant: Instant) class UDFSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -504,23 +513,94 @@ class UDFSuite extends QueryTest with SharedSparkSession { } test("Using java.time.Instant in UDF") { - withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { - val expected = java.time.Instant.parse("2019-02-27T00:00:00Z") - val plusSec = udf((i: java.time.Instant) => i.plusSeconds(1)) - val df = spark.sql("SELECT TIMESTAMP '2019-02-26 23:59:59Z' as t") - .select(plusSec('t)) - assert(df.collect().toSeq === Seq(Row(expected))) - } + val dtf = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss") + val expected = java.time.Instant.parse("2019-02-27T00:00:00Z") + .atZone(DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .toLocalDateTime + .format(dtf) + val plusSec = udf((i: java.time.Instant) => i.plusSeconds(1)) + val df = spark.sql("SELECT TIMESTAMP '2019-02-26 23:59:59Z' as t") + .select(plusSec('t).cast(StringType)) + checkAnswer(df, Row(expected) :: Nil) } test("Using java.time.LocalDate in UDF") { - withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { - val expected = java.time.LocalDate.parse("2019-02-27") - val plusDay = udf((i: java.time.LocalDate) => i.plusDays(1)) - val df = spark.sql("SELECT 
DATE '2019-02-26' as d") - .select(plusDay('d)) - assert(df.collect().toSeq === Seq(Row(expected))) - } + val expected = java.time.LocalDate.parse("2019-02-27").toString + val plusDay = udf((i: java.time.LocalDate) => i.plusDays(1)) + val df = spark.sql("SELECT DATE '2019-02-26' as d") + .select(plusDay('d).cast(StringType)) + checkAnswer(df, Row(expected) :: Nil) + } + + test("Using combined types of Instant/LocalDate in UDF") { + val dtf = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss") + val date = LocalDate.parse("2019-02-26") + val instant = Instant.parse("2019-02-26T23:59:59Z") + val expectedDate = date.toString + val expectedInstant = + instant.atZone(DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .toLocalDateTime + .format(dtf) + val df = Seq((date, instant)).toDF("d", "i") + + // test normal case + spark.udf.register("buildLocalDateInstantType", + udf((d: LocalDate, i: Instant) => LocalDateInstantType(d, i))) + checkAnswer(df.selectExpr(s"buildLocalDateInstantType(d, i) as di") + .select('di.cast(StringType)), + Row(s"[$expectedDate, $expectedInstant]") :: Nil) + + // test null cases + spark.udf.register("buildLocalDateInstantType", + udf((d: LocalDate, i: Instant) => LocalDateInstantType(null, null))) + checkAnswer(df.selectExpr("buildLocalDateInstantType(d, i) as di"), + Row(Row(null, null))) + + spark.udf.register("buildLocalDateInstantType", + udf((d: LocalDate, i: Instant) => null.asInstanceOf[LocalDateInstantType])) + checkAnswer(df.selectExpr("buildLocalDateInstantType(d, i) as di"), + Row(null)) + } + + test("Using combined types of Instant/Timestamp in UDF") { + val dtf = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss") + val timestamp = Timestamp.valueOf("2019-02-26 23:59:59") + val instant = Instant.parse("2019-02-26T23:59:59Z") + val expectedTimestamp = timestamp.toLocalDateTime.format(dtf) + val expectedInstant = + instant.atZone(DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .toLocalDateTime + .format(dtf) + val df = Seq((timestamp, instant)).toDF("t", "i") + + // test normal case + spark.udf.register("buildTimestampInstantType", + udf((t: Timestamp, i: Instant) => TimestampInstantType(t, i))) + checkAnswer(df.selectExpr("buildTimestampInstantType(t, i) as ti") + .select('ti.cast(StringType)), + Row(s"[$expectedTimestamp, $expectedInstant]")) + + // test null cases + spark.udf.register("buildTimestampInstantType", + udf((t: Timestamp, i: Instant) => TimestampInstantType(null, null))) + checkAnswer(df.selectExpr("buildTimestampInstantType(t, i) as ti"), + Row(Row(null, null))) + + spark.udf.register("buildTimestampInstantType", + udf((t: Timestamp, i: Instant) => null.asInstanceOf[TimestampInstantType])) + checkAnswer(df.selectExpr("buildTimestampInstantType(t, i) as ti"), + Row(null)) + } + + test("SPARK-32154: return null with or without explicit type") { + // without explicit type + val udf1 = udf((i: String) => null) + assert(udf1.asInstanceOf[SparkUserDefinedFunction] .dataType === NullType) + checkAnswer(Seq("1").toDF("a").select(udf1('a)), Row(null) :: Nil) + // with explicit type + val udf2 = udf((i: String) => null.asInstanceOf[String]) + assert(udf2.asInstanceOf[SparkUserDefinedFunction].dataType === StringType) + checkAnswer(Seq("1").toDF("a").select(udf1('a)), Row(null) :: Nil) } test("SPARK-28321 0-args Java UDF should not be called only once") { @@ -669,4 +749,42 @@ class UDFSuite extends QueryTest with SharedSparkSession { val df = Seq(Array(Some(TestData(50, "2")), None)).toDF("col") checkAnswer(df.select(myUdf(Column("col"))), 
Row(100) :: Nil) } + + object MalformedClassObject extends Serializable { + class MalformedNonPrimitiveFunction extends (String => Int) with Serializable { + override def apply(v1: String): Int = v1.toInt / 0 + } + + class MalformedPrimitiveFunction extends (Int => Int) with Serializable { + override def apply(v1: Int): Int = v1 / 0 + } + } + + test("SPARK-32238: Use Utils.getSimpleName to avoid hitting Malformed class name") { + OuterScopes.addOuterScope(MalformedClassObject) + val f1 = new MalformedClassObject.MalformedNonPrimitiveFunction() + val f2 = new MalformedClassObject.MalformedPrimitiveFunction() + + val e1 = intercept[SparkException] { + Seq("20").toDF("col").select(udf(f1).apply(Column("col"))).collect() + } + assert(e1.getMessage.contains("UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction")) + + val e2 = intercept[SparkException] { + Seq(20).toDF("col").select(udf(f2).apply(Column("col"))).collect() + } + assert(e2.getMessage.contains("UDFSuite$MalformedClassObject$MalformedPrimitiveFunction")) + } + + test("SPARK-32307: Aggression that use map type input UDF as group expression") { + spark.udf.register("key", udf((m: Map[String, String]) => m.keys.head.toInt)) + Seq(Map("1" -> "one", "2" -> "two")).toDF("a").createOrReplaceTempView("t") + checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil) + } + + test("SPARK-32307: Aggression that use array type input UDF as group expression") { + spark.udf.register("key", udf((m: Array[Int]) => m.head)) + Seq(Array(1)).toDF("a").createOrReplaceTempView("t") + checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index a5f904c621e6e..9daa69ce9f155 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -178,4 +178,14 @@ class UnsafeRowSuite extends SparkFunSuite { // Makes sure hashCode on unsafe array won't crash unsafeRow.getArray(0).hashCode() } + + test("SPARK-32018: setDecimal with overflowed value") { + val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18) + val row = InternalRow.apply(d1) + val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row) + assert(unsafeRow.getDecimal(0, 38, 18) === d1) + val d2 = (d1 * Decimal(10)).toPrecision(39, 18) + unsafeRow.setDecimal(0, d2, 38) + assert(unsafeRow.getDecimal(0, 38, 18) === null) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index f7f4df8f2d2e9..85aea3ce41ecc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.connector +import java.sql.Timestamp +import java.time.LocalDate + import scala.collection.JavaConverters._ import org.apache.spark.SparkException @@ -27,7 +30,7 @@ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} -import 
org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION} import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.SimpleScanSource import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} @@ -1647,7 +1650,6 @@ class DataSourceV2SQLSuite """ |CREATE TABLE testcat.t (id int, `a.b` string) USING foo |CLUSTERED BY (`a.b`) INTO 4 BUCKETS - |OPTIONS ('allow-unsupported-transforms'=true) """.stripMargin) val testCatalog = catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog] @@ -2494,6 +2496,38 @@ class DataSourceV2SQLSuite } } + test("SPARK-32168: INSERT OVERWRITE - hidden days partition - dynamic mode") { + def testTimestamp(daysOffset: Int): Timestamp = { + Timestamp.valueOf(LocalDate.of(2020, 1, 1 + daysOffset).atStartOfDay()) + } + + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTable(t1) { + val df = spark.createDataFrame(Seq( + (testTimestamp(1), "a"), + (testTimestamp(2), "b"), + (testTimestamp(3), "c"))).toDF("ts", "data") + df.createOrReplaceTempView("source_view") + + sql(s"CREATE TABLE $t1 (ts timestamp, data string) " + + s"USING $v2Format PARTITIONED BY (days(ts))") + sql(s"INSERT INTO $t1 VALUES " + + s"(CAST(date_add('2020-01-01', 2) AS timestamp), 'dummy'), " + + s"(CAST(date_add('2020-01-01', 4) AS timestamp), 'keep')") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT ts, data FROM source_view") + + val expected = spark.createDataFrame(Seq( + (testTimestamp(1), "a"), + (testTimestamp(2), "b"), + (testTimestamp(3), "c"), + (testTimestamp(4), "keep"))).toDF("ts", "data") + + verifyTable(t1, expected) + } + } + } + private def testV1Command(sqlCommand: String, sqlParams: String): Unit = { val e = intercept[AnalysisException] { sql(s"$sqlCommand $sqlParams") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala index b88ad5218fcd2..2cc7a1f994645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala @@ -446,21 +446,18 @@ trait InsertIntoSQLOnlyTests } } - test("InsertInto: overwrite - multiple static partitions - dynamic mode") { - // Since all partitions are provided statically, this should be supported by everyone - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = s"${catalogAndNamespace}tbl" - withTableAndData(t1) { view => - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + - s"USING $v2Format PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view") - verifyTable(t1, Seq( - (2, "a", 2), - (2, "b", 2), - (2, "c", 2), - (4, "keep", 2)).toDF("id", "data", "p")) - } + dynamicOverwriteTest("InsertInto: overwrite - multiple static partitions - dynamic mode") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) 
SELECT data FROM $view") + verifyTable(t1, Seq( + (2, "a", 2), + (2, "b", 2), + (2, "c", 2), + (4, "keep", 2)).toDF("id", "data", "p")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index b29de9c4adbaa..98aba3ba25f17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -27,32 +27,29 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext { private val random = new java.util.Random() - private var taskContext: TaskContext = _ - - override def afterAll(): Unit = try { - TaskContext.unset() - } finally { - super.afterAll() - } private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int) (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { sc = new SparkContext("local", "test", new SparkConf(false)) - taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) TaskContext.setTaskContext(taskContext) - val array = new ExternalAppendOnlyUnsafeRowArray( - taskContext.taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - taskContext, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - inMemoryThreshold, - spillThreshold) - try f(array) finally { - array.clear() + try { + val array = new ExternalAppendOnlyUnsafeRowArray( + taskContext.taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + taskContext, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + inMemoryThreshold, + spillThreshold) + try f(array) finally { + array.clear() + } + } finally { + TaskContext.unset() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/IntervalBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/IntervalBenchmark.scala index 96ad453aeb2d7..a9696e6718de8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/IntervalBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/IntervalBenchmark.scala @@ -109,11 +109,11 @@ object IntervalBenchmark extends SqlBasedBenchmark { // The first 2 cases are used to show the overhead of preparing the interval string. 
addCase(benchmark, cardinality, "prepare string w/ interval", buildString(true, timeUnits)) addCase(benchmark, cardinality, "prepare string w/o interval", buildString(false, timeUnits)) - addCase(benchmark, cardinality, intervalToTest) // Only years + addCase(benchmark, cardinality, intervalToTest.toSeq) // Only years for (unit <- timeUnits) { intervalToTest.append(unit) - addCase(benchmark, cardinality, intervalToTest) + addCase(benchmark, cardinality, intervalToTest.toSeq) } benchmark.run() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 8b859e951b9b9..d51eafa5a8aed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -497,6 +497,26 @@ abstract class SchemaPruningSuite Row(Row("Janet", null, "Jones"), "Jones") ::Nil) } + testSchemaPruning("SPARK-32163: nested pruning should work even with cosmetic variations") { + withTempView("contact_alias") { + sql("select * from contacts") + .repartition(100, col("name.first"), col("name.last")) + .selectExpr("name").createOrReplaceTempView("contact_alias") + + val query1 = sql("select name.first from contact_alias") + checkScan(query1, "struct>") + checkAnswer(query1, Row("Jane") :: Row("John") :: Row("Jim") :: Row("Janet") ::Nil) + + sql("select * from contacts") + .select(explode(col("friends.first")), col("friends")) + .createOrReplaceTempView("contact_alias") + + val query2 = sql("select friends.middle, col from contact_alias") + checkScan(query2, "struct>>") + checkAnswer(query2, Row(Array("Z."), "Susan") :: Nil) + } + } + protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index c7448b12626be..de01099f2db55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2488,7 +2488,7 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson .json(testFile("test-data/utf16LE.json")) .count() } - assert(exception.getMessage.contains("encoding must not be included in the blacklist")) + assert(exception.getMessage.contains("encoding must not be included in the denyList")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 8bbf81efff316..ce726046c3215 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -220,9 +220,9 @@ trait SQLMetricsTestUtils extends SQLTestUtils { (nodeName, nodeMetrics.mapValues(expectedMetricValue => (actualMetricValue: Any) => { actualMetricValue.toString.matches(expectedMetricValue.toString) - })) + }).toMap) } - testSparkPlanMetricsWithPredicates(df, expectedNumOfJobs, expectedMetricsPredicates, + 
testSparkPlanMetricsWithPredicates(df, expectedNumOfJobs, expectedMetricsPredicates.toMap, enableWholeStage) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleSuite.scala deleted file mode 100644 index 54ec4a8352c1b..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ /dev/null @@ -1,423 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous.shuffle - -import java.util.UUID - -import scala.language.implicitConversions - -import org.apache.spark.{HashPartitioner, TaskContext, TaskContextImpl} -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class ContinuousShuffleSuite extends StreamTest { - // In this unit test, we emulate that we're in the task thread where - // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context - // thread local to be set. 
- var ctx: TaskContextImpl = _ - - override def beforeEach(): Unit = { - super.beforeEach() - ctx = TaskContext.empty() - TaskContext.setTaskContext(ctx) - } - - override def afterEach(): Unit = { - ctx.markTaskCompleted(None) - TaskContext.unset() - ctx = null - super.afterEach() - } - - private implicit def unsafeRow(value: Int) = { - UnsafeProjection.create(Array(IntegerType : DataType))( - new GenericInternalRow(Array(value: Any))) - } - - private def unsafeRow(value: String) = { - UnsafeProjection.create(Array(StringType : DataType))( - new GenericInternalRow(Array(UTF8String.fromString(value): Any))) - } - - private def send(endpoint: RpcEndpointRef, messages: RPCContinuousShuffleMessage*) = { - messages.foreach(endpoint.askSync[Unit](_)) - } - - private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = { - rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - } - - private def readEpoch(rdd: ContinuousShuffleReadRDD) = { - rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) - } - - test("reader - one epoch") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverRow(0, unsafeRow(222)), - ReceiverRow(0, unsafeRow(333)), - ReceiverEpochMarker(0) - ) - - val iter = rdd.compute(rdd.partitions(0), ctx) - assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) - } - - test("reader - multiple epochs") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(222)), - ReceiverRow(0, unsafeRow(333)), - ReceiverEpochMarker(0) - ) - - val firstEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) - - val secondEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) - } - - test("reader - empty epochs") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - - send( - endpoint, - ReceiverEpochMarker(0), - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0), - ReceiverEpochMarker(0), - ReceiverEpochMarker(0) - ) - - assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) - assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) - - val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) - - assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) - assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) - } - - test("reader - multiple partitions") { - val rdd = new ContinuousShuffleReadRDD( - sparkContext, - numPartitions = 5, - endpointNames = Seq.fill(5)(s"endpt-${UUID.randomUUID()}")) - // Send all data before processing to ensure there's no crossover. - for (p <- rdd.partitions) { - val part = p.asInstanceOf[ContinuousShuffleReadPartition] - // Send index for identification. 
- send( - part.endpoint, - ReceiverRow(0, unsafeRow(part.index)), - ReceiverEpochMarker(0) - ) - } - - for (p <- rdd.partitions) { - val part = p.asInstanceOf[ContinuousShuffleReadPartition] - val iter = rdd.compute(part, ctx) - assert(iter.next().getInt(0) == part.index) - assert(!iter.hasNext) - } - } - - test("reader - blocks waiting for new rows") { - val rdd = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) - val epoch = rdd.compute(rdd.partitions(0), ctx) - - val readRowThread = new Thread { - override def run(): Unit = { - try { - epoch.next().getInt(0) - } catch { - case _: InterruptedException => // do nothing - expected at test ending - } - } - } - - try { - readRowThread.start() - eventually(timeout(streamingTimeout)) { - assert(readRowThread.getState == Thread.State.TIMED_WAITING) - } - } finally { - readRowThread.interrupt() - readRowThread.join() - } - } - - test("reader - multiple writers") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow("writer0-row0")), - ReceiverRow(1, unsafeRow("writer1-row0")), - ReceiverRow(2, unsafeRow("writer2-row0")), - ReceiverEpochMarker(0), - ReceiverEpochMarker(1), - ReceiverEpochMarker(2) - ) - - val firstEpoch = rdd.compute(rdd.partitions(0), ctx) - assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == - Set("writer0-row0", "writer1-row0", "writer2-row0")) - } - - test("reader - epoch only ends when all writers send markers") { - val rdd = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow("writer0-row0")), - ReceiverRow(1, unsafeRow("writer1-row0")), - ReceiverRow(2, unsafeRow("writer2-row0")), - ReceiverEpochMarker(0), - ReceiverEpochMarker(2) - ) - - val epoch = rdd.compute(rdd.partitions(0), ctx) - val rows = (0 until 3).map(_ => epoch.next()).toSet - assert(rows.map(_.getUTF8String(0).toString) == - Set("writer0-row0", "writer1-row0", "writer2-row0")) - - // After checking the right rows, block until we get an epoch marker indicating there's no next. - // (Also fail the assertion if for some reason we get a row.) - - val readEpochMarkerThread = new Thread { - override def run(): Unit = { - assert(!epoch.hasNext) - } - } - - readEpochMarkerThread.start() - eventually(timeout(streamingTimeout)) { - assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) - } - - // Send the last epoch marker - now the epoch should finish. - send(endpoint, ReceiverEpochMarker(1)) - eventually(timeout(streamingTimeout)) { - !readEpochMarkerThread.isAlive - } - - // Join to pick up assertion failures. - readEpochMarkerThread.join(streamingTimeout.toMillis) - } - - test("reader - writer epochs non aligned") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should - // collate them as though the markers were aligned in the first place. 
- send( - endpoint, - ReceiverRow(0, unsafeRow("writer0-row0")), - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow("writer0-row1")), - ReceiverEpochMarker(0), - ReceiverEpochMarker(0), - - ReceiverEpochMarker(1), - ReceiverRow(1, unsafeRow("writer1-row0")), - ReceiverEpochMarker(1), - ReceiverRow(1, unsafeRow("writer1-row1")), - ReceiverEpochMarker(1), - - ReceiverEpochMarker(2), - ReceiverEpochMarker(2), - ReceiverRow(2, unsafeRow("writer2-row0")), - ReceiverEpochMarker(2) - ) - - val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet - assert(firstEpoch == Set("writer0-row0")) - - val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet - assert(secondEpoch == Set("writer0-row1", "writer1-row0")) - - val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet - assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) - } - - test("one epoch") { - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new RPCContinuousShuffleWriter( - 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - - writer.write(Iterator(1, 2, 3)) - - assert(readEpoch(reader) == Seq(1, 2, 3)) - } - - test("multiple epochs") { - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new RPCContinuousShuffleWriter( - 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - - writer.write(Iterator(1, 2, 3)) - writer.write(Iterator(4, 5, 6)) - - assert(readEpoch(reader) == Seq(1, 2, 3)) - assert(readEpoch(reader) == Seq(4, 5, 6)) - } - - test("empty epochs") { - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new RPCContinuousShuffleWriter( - 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - - writer.write(Iterator()) - writer.write(Iterator(1, 2)) - writer.write(Iterator()) - writer.write(Iterator()) - writer.write(Iterator(3, 4)) - writer.write(Iterator()) - - assert(readEpoch(reader) == Seq()) - assert(readEpoch(reader) == Seq(1, 2)) - assert(readEpoch(reader) == Seq()) - assert(readEpoch(reader) == Seq()) - assert(readEpoch(reader) == Seq(3, 4)) - assert(readEpoch(reader) == Seq()) - } - - test("blocks waiting for writer") { - val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val writer = new RPCContinuousShuffleWriter( - 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - - val readerEpoch = reader.compute(reader.partitions(0), ctx) - - val readRowThread = new Thread { - override def run(): Unit = { - assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) - } - } - readRowThread.start() - - eventually(timeout(streamingTimeout)) { - assert(readRowThread.getState == Thread.State.TIMED_WAITING) - } - - // Once we write the epoch the thread should stop waiting and succeed. 
- writer.write(Iterator(1)) - readRowThread.join(streamingTimeout.toMillis) - } - - test("multiple writer partitions") { - val numWriterPartitions = 3 - - val reader = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) - val writers = (0 until 3).map { idx => - new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - } - - writers(0).write(Iterator(1, 4, 7)) - writers(1).write(Iterator(2, 5)) - writers(2).write(Iterator(3, 6)) - - writers(0).write(Iterator(4, 7, 10)) - writers(1).write(Iterator(5, 8)) - writers(2).write(Iterator(6, 9)) - - // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. - // The epochs should be deterministically preserved, however. - assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) - assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) - } - - test("reader epoch only ends when all writer partitions write it") { - val numWriterPartitions = 3 - - val reader = new ContinuousShuffleReadRDD( - sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) - val writers = (0 until 3).map { idx => - new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) - } - - writers(1).write(Iterator()) - writers(2).write(Iterator()) - - val readerEpoch = reader.compute(reader.partitions(0), ctx) - - val readEpochMarkerThread = new Thread { - override def run(): Unit = { - assert(!readerEpoch.hasNext) - } - } - - readEpochMarkerThread.start() - eventually(timeout(streamingTimeout)) { - assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) - } - - writers(0).write(Iterator()) - readEpochMarkerThread.join(streamingTimeout.toMillis) - } - - test("receiver stopped with row last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(111)) - ) - - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) - } - } - - test("receiver stopped with marker last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0) - ) - - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index e87bd11f0dca5..0fe339b93047a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -333,6 +333,6 @@ class TestForeachWriter extends ForeachWriter[Int] { override def close(errorOrNull: Throwable): Unit = { events += ForeachWriterSuite.Close(error = Option(errorOrNull)) - 
ForeachWriterSuite.addEvents(events) + ForeachWriterSuite.addEvents(events.toSeq) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 4d5cd109b7c24..b033761498ea7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -209,21 +209,24 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(statusStore.executionMetrics(executionId), + accumulatorUpdates.mapValues(_ * 2).toMap) // Driver accumulator updates don't belong to this execution should be filtered and no // exception will be thrown. listener.onOtherEvent(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L)))) - checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(statusStore.executionMetrics(executionId), + accumulatorUpdates.mapValues(_ * 2).toMap) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, accum updates) (0L, 0, 0, createAccumulatorInfos(accumulatorUpdates)), - (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates.mapValues(_ * 2))) + (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates.mapValues(_ * 2).toMap)) ))) - checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 3)) + checkAnswer(statusStore.executionMetrics(executionId), + accumulatorUpdates.mapValues(_ * 3).toMap) // Retrying a stage should reset the metrics listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) @@ -236,7 +239,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils (1L, 0, 1, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(statusStore.executionMetrics(executionId), + accumulatorUpdates.mapValues(_ * 2).toMap) // Ignore the task end for the first attempt listener.onTaskEnd(SparkListenerTaskEnd( @@ -244,11 +248,12 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils stageAttemptId = 0, taskType = "", reason = null, - createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 100)), + createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 100).toMap), new ExecutorMetrics, null)) - checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(statusStore.executionMetrics(executionId), + accumulatorUpdates.mapValues(_ * 2).toMap) // Finish two tasks listener.onTaskEnd(SparkListenerTaskEnd( @@ -256,7 +261,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils stageAttemptId = 1, taskType = "", reason = null, - createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 2)), + createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 2).toMap), new ExecutorMetrics, null)) listener.onTaskEnd(SparkListenerTaskEnd( @@ -264,11 +269,12 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils stageAttemptId = 1, taskType = "", reason = null, - createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)), + createTaskInfo(1, 0, 
accums = accumulatorUpdates.mapValues(_ * 3).toMap), new ExecutorMetrics, null)) - checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 5)) + checkAnswer(statusStore.executionMetrics(executionId), + accumulatorUpdates.mapValues(_ * 5).toMap) // Summit a new stage listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) @@ -281,7 +287,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils (1L, 1, 0, createAccumulatorInfos(accumulatorUpdates)) ))) - checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 7)) + checkAnswer(statusStore.executionMetrics(executionId), + accumulatorUpdates.mapValues(_ * 7).toMap) // Finish two tasks listener.onTaskEnd(SparkListenerTaskEnd( @@ -289,7 +296,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils stageAttemptId = 0, taskType = "", reason = null, - createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 3)), + createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 3).toMap), new ExecutorMetrics, null)) listener.onTaskEnd(SparkListenerTaskEnd( @@ -297,11 +304,12 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils stageAttemptId = 0, taskType = "", reason = null, - createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)), + createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3).toMap), new ExecutorMetrics, null)) - checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 11)) + checkAnswer(statusStore.executionMetrics(executionId), + accumulatorUpdates.mapValues(_ * 11).toMap) assertJobs(statusStore.execution(executionId), running = Seq(0)) @@ -315,7 +323,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils assertJobs(statusStore.execution(executionId), completed = Seq(0)) - checkAnswer(statusStore.executionMetrics(executionId), accumulatorUpdates.mapValues(_ * 11)) + checkAnswer(statusStore.executionMetrics(executionId), + accumulatorUpdates.mapValues(_ * 11).toMap) } test("control a plan explain mode in listeners via SQLConf") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 8d5439534b513..5e401f5136019 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -873,7 +873,7 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with } if(!running) { actions += StartStream() } addCheck() - testStream(ds)(actions: _*) + testStream(ds)(actions.toSeq: _*) } object AwaitTerminationTester { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index caca749f9dd1e..b182727408bbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.File +import java.sql.Timestamp import java.util.{Locale, UUID} import scala.util.Random @@ -996,4 +997,47 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with ) } } + + test("SPARK-32148 stream-stream join regression on Spark 3.0.0") { + val input1 = MemoryStream[(Timestamp, 
String, String)] + val df1 = input1.toDF + .selectExpr("_1 as eventTime", "_2 as id", "_3 as comment") + .withWatermark(s"eventTime", "2 minutes") + + val input2 = MemoryStream[(Timestamp, String, String)] + val df2 = input2.toDF + .selectExpr("_1 as eventTime", "_2 as id", "_3 as name") + .withWatermark(s"eventTime", "4 minutes") + + val joined = df1.as("left") + .join(df2.as("right"), + expr(""" + |left.id = right.id AND left.eventTime BETWEEN + | right.eventTime - INTERVAL 30 seconds AND + | right.eventTime + INTERVAL 30 seconds + """.stripMargin), + joinType = "leftOuter") + + val inputDataForInput1 = Seq( + (Timestamp.valueOf("2020-01-01 00:00:00"), "abc", "has no join partner"), + (Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "joined with A"), + (Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "joined with B")) + + val inputDataForInput2 = Seq( + (Timestamp.valueOf("2020-01-02 00:00:10"), "abc", "A"), + (Timestamp.valueOf("2020-01-02 00:59:59"), "abc", "B"), + (Timestamp.valueOf("2020-01-02 02:00:00"), "abc", "C")) + + val expectedOutput = Seq( + (Timestamp.valueOf("2020-01-01 00:00:00"), "abc", "has no join partner", null, null, null), + (Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "joined with A", + Timestamp.valueOf("2020-01-02 00:00:10"), "abc", "A"), + (Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "joined with B", + Timestamp.valueOf("2020-01-02 00:59:59"), "abc", "B")) + + testStream(joined)( + MultiAddData((input1, inputDataForInput1), (input2, inputDataForInput2)), + CheckNewAnswer(expectedOutput.head, expectedOutput.tail: _*) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 6e08b88f538df..26158f4d639ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -323,7 +323,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { actions += AssertOnQuery { q => q.recentProgress.size > 1 && q.recentProgress.size <= 11 } - testStream(input.toDS)(actions: _*) + testStream(input.toDS)(actions.toSeq: _*) spark.sparkContext.listenerBus.waitUntilEmpty() // 11 is the max value of the possible numbers of events. 
assert(numProgressEvent > 1 && numProgressEvent <= 11) @@ -559,11 +559,11 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { private val _progressEvents = new mutable.Queue[StreamingQueryProgress] def progressEvents: Seq[StreamingQueryProgress] = _progressEvents.synchronized { - _progressEvents.filter(_.numInputRows > 0) + _progressEvents.filter(_.numInputRows > 0).toSeq } def allProgressEvents: Seq[StreamingQueryProgress] = _progressEvents.synchronized { - _progressEvents.clone() + _progressEvents.clone().toSeq } def reset(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 98e2342c78e56..ec61102804ea3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -316,7 +316,7 @@ object StreamingQueryStatusAndProgressSuite { timestamp = "2016-12-05T20:54:20.827Z", batchId = 2L, batchDuration = 0L, - durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava), + durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).toMap.asJava), eventTime = new java.util.HashMap(Map( "max" -> "2016-12-05T20:54:20.827Z", "min" -> "2016-12-05T20:54:20.827Z", @@ -326,7 +326,7 @@ object StreamingQueryStatusAndProgressSuite { numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 3, numRowsDroppedByWatermark = 0, customMetrics = new java.util.HashMap(Map("stateOnCurrentVersionSizeBytes" -> 2L, "loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L) - .mapValues(long2Long).asJava) + .mapValues(long2Long).toMap.asJava) )), sources = Array( new SourceProgress( @@ -351,7 +351,7 @@ object StreamingQueryStatusAndProgressSuite { timestamp = "2016-12-05T20:54:20.827Z", batchId = 2L, batchDuration = 0L, - durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava), + durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).toMap.asJava), // empty maps should be handled correctly eventTime = new java.util.HashMap(Map.empty[String, String].asJava), stateOperators = Array(new StateOperatorProgress( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala deleted file mode 100644 index 3ec4750c59fc5..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
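A note on the pattern running through the test-suite hunks above: the `.toMap` added after `mapValues`/`filterKeys` and the `.toSeq` added after mutable buffers appear to be there for the Scala 2.13 cross-build, where those calls yield lazy views or mutable collections that no longer satisfy `Map`/`immutable.Seq` signatures. A minimal standalone sketch of the issue (illustrative only, not part of the patch):

// Sketch only: names below are illustrative and independent of Spark.
object CollectionCompatSketch {
  def main(args: Array[String]): Unit = {
    val metrics = Map("rows" -> 1L, "bytes" -> 10L)

    // Scala 2.12: mapValues returns a Map; 2.13: it returns a lazy MapView.
    // Appending .toMap keeps call sites that expect Map[String, Long] compiling on both.
    val doubled: Map[String, Long] = metrics.mapValues(_ * 2).toMap

    // Scala 2.13: scala.Seq is immutable.Seq, so an ArrayBuffer is no longer a Seq.
    // Appending .toSeq materializes an immutable copy, as in the testStream(actions.toSeq: _*) hunks above.
    val buffer = scala.collection.mutable.ArrayBuffer(1, 2, 3)
    val asSeq: Seq[Int] = buffer.toSeq

    println((doubled, asSeq))
  }
}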
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming.continuous - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED -import org.apache.spark.sql.streaming.OutputMode - -class ContinuousAggregationSuite extends ContinuousSuiteBase { - import testImplicits._ - - test("not enabled") { - val ex = intercept[AnalysisException] { - val input = ContinuousMemoryStream.singlePartition[Int] - testStream(input.toDF().agg(max('value)), OutputMode.Complete)() - } - - assert(ex.getMessage.contains( - "In continuous processing mode, coalesce(1) must be called before aggregate operation")) - } - - test("basic") { - withSQLConf((UNSUPPORTED_OPERATION_CHECK_ENABLED.key, "false")) { - val input = ContinuousMemoryStream.singlePartition[Int] - - testStream(input.toDF().agg(max('value)), OutputMode.Complete)( - AddData(input, 0, 1, 2), - CheckAnswer(2), - StopStream, - AddData(input, 3, 4, 5), - StartStream(), - CheckAnswer(5), - AddData(input, -1, -2, -3), - CheckAnswer(5)) - } - } - - test("multiple partitions with coalesce") { - val input = ContinuousMemoryStream[Int] - - val df = input.toDF().coalesce(1).agg(max('value)) - - testStream(df, OutputMode.Complete)( - AddData(input, 0, 1, 2), - CheckAnswer(2), - StopStream, - AddData(input, 3, 4, 5), - StartStream(), - CheckAnswer(5), - AddData(input, -1, -2, -3), - CheckAnswer(5)) - } - - test("multiple partitions with coalesce - multiple transformations") { - val input = ContinuousMemoryStream[Int] - - // We use a barrier to make sure predicates both before and after coalesce work - val df = input.toDF() - .select('value as 'copy, 'value) - .where('copy =!= 1) - .logicalPlan - .coalesce(1) - .where('copy =!= 2) - .agg(max('value)) - - testStream(df, OutputMode.Complete)( - AddData(input, 0, 1, 2), - CheckAnswer(0), - StopStream, - AddData(input, 3, 4, 5), - StartStream(), - CheckAnswer(5), - AddData(input, -1, -2, -3), - CheckAnswer(5)) - } - - test("multiple partitions with multiple coalesce") { - val input = ContinuousMemoryStream[Int] - - val df = input.toDF() - .coalesce(1) - .logicalPlan - .coalesce(1) - .select('value as 'copy, 'value) - .agg(max('value)) - - testStream(df, OutputMode.Complete)( - AddData(input, 0, 1, 2), - CheckAnswer(2), - StopStream, - AddData(input, 3, 4, 5), - StartStream(), - CheckAnswer(5), - AddData(input, -1, -2, -3), - CheckAnswer(5)) - } - - test("repeated restart") { - withSQLConf((UNSUPPORTED_OPERATION_CHECK_ENABLED.key, "false")) { - val input = ContinuousMemoryStream.singlePartition[Int] - - testStream(input.toDF().agg(max('value)), OutputMode.Complete)( - AddData(input, 0, 1, 2), - CheckAnswer(2), - StopStream, - StartStream(), - StopStream, - StartStream(), - StopStream, - StartStream(), - AddData(input, 0), - CheckAnswer(2), - AddData(input, 5), - CheckAnswer(5)) - } - } -} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index eae5d5d4bcfa9..57ed15a76a893 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -333,8 +333,8 @@ private[hive] class SparkExecuteStatementOperation( synchronized { if (!getStatus.getState.isTerminal) { logInfo(s"Cancel query with $statementId") - cleanup() setState(OperationState.CANCELED) + cleanup() HiveThriftServer2.eventManager.onStatementCanceled(statementId) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala index 0acd1b3e9899a..446669d08e76b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala @@ -46,8 +46,8 @@ private[hive] trait SparkOperation extends Operation with Logging { } abstract override def close(): Unit = { - cleanup() super.close() + cleanup() logInfo(s"Close statement with $statementId") HiveThriftServer2.eventManager.onOperationClosed(statementId) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 233e6224a10d9..109b7f4bb31bb 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils -/** A singleton object for the master program. The slaves should not access this. */ +/** A singleton object for the master program. The executors should not access this. */ private[hive] object SparkSQLEnv extends Logging { logDebug("Initializing SparkSQLEnv") diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 8546421a86927..2064a99137bf9 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -132,6 +132,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { } var next = 0 + val foundMasterAndApplicationIdMessage = Promise.apply[Unit]() val foundAllExpectedAnswers = Promise.apply[Unit]() val buffer = new ArrayBuffer[String]() val lock = new Object @@ -143,6 +144,10 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { log.info(newLine) buffer += newLine + if (line.startsWith("Spark master: ") && line.contains("Application Id: ")) { + foundMasterAndApplicationIdMessage.trySuccess(()) + } + // If we haven't found all expected answers and another expected answer comes up... 
if (next < expectedAnswers.size && line.contains(expectedAnswers(next))) { log.info(s"$source> found expected output line $next: '${expectedAnswers(next)}'") @@ -172,7 +177,18 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - ThreadUtils.awaitResult(foundAllExpectedAnswers.future, timeout) + val timeoutForQuery = if (!extraArgs.contains("-e")) { + // Wait for for cli driver to boot, up to two minutes + ThreadUtils.awaitResult(foundMasterAndApplicationIdMessage.future, 2.minutes) + log.info("Cli driver is booted. Waiting for expected answers.") + // Given timeout is applied after the cli driver is ready + timeout + } else { + // There's no boot message if -e option is provided, just extend timeout long enough + // so that the bootup duration is counted on the timeout + 2.minutes + timeout + } + ThreadUtils.awaitResult(foundAllExpectedAnswers.future, timeoutForQuery) log.info("Found all expected output.") } catch { case cause: Throwable => val message = @@ -194,7 +210,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { } finally { if (!process.waitFor(1, MINUTES)) { try { - fail("spark-sql did not exit gracefully.") + log.warn("spark-sql did not exit gracefully.") } finally { process.destroy() } @@ -447,7 +463,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { val jarFile = new File("../../sql/hive/src/test/resources/SPARK-21101-1.0.jar").getCanonicalPath val hiveContribJar = HiveTestJars.getHiveContribJar().getCanonicalPath runCliWithin( - 1.minute, + 2.minutes, Seq("--jars", s"$jarFile", "--conf", s"spark.hadoop.${ConfVars.HIVEAUXJARS}=$hiveContribJar"))( "CREATE TEMPORARY FUNCTION testjar AS" + diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveSessionImplSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveSessionImplSuite.scala index 05d540d782e31..42d86e98a7273 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveSessionImplSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveSessionImplSuite.scala @@ -16,24 +16,31 @@ */ package org.apache.spark.sql.hive.thriftserver +import java.lang.reflect.InvocationTargetException +import java.nio.ByteBuffer +import java.util.UUID + import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.hadoop.hive.conf.HiveConf import org.apache.hive.service.cli.OperationHandle -import org.apache.hive.service.cli.operation.{GetCatalogsOperation, OperationManager} -import org.apache.hive.service.cli.session.{HiveSessionImpl, SessionManager} -import org.mockito.Mockito.{mock, verify, when} -import org.mockito.invocation.InvocationOnMock +import org.apache.hive.service.cli.operation.{GetCatalogsOperation, Operation, OperationManager} +import org.apache.hive.service.cli.session.{HiveSession, HiveSessionImpl, SessionManager} +import org.apache.hive.service.rpc.thrift.{THandleIdentifier, TOperationHandle, TOperationType} import org.apache.spark.SparkFunSuite class HiveSessionImplSuite extends SparkFunSuite { private var session: HiveSessionImpl = _ - private var operationManager: OperationManager = _ + private var operationManager: OperationManagerMock = _ override def beforeAll() { super.beforeAll() + val sessionManager = new SessionManager(null) + operationManager = new OperationManagerMock() + 
session = new HiveSessionImpl( ThriftserverShimUtils.testedProtocolVersions.head, "", @@ -41,17 +48,8 @@ class HiveSessionImplSuite extends SparkFunSuite { new HiveConf(), "" ) - val sessionManager = mock(classOf[SessionManager]) session.setSessionManager(sessionManager) - operationManager = mock(classOf[OperationManager]) session.setOperationManager(operationManager) - when(operationManager.newGetCatalogsOperation(session)).thenAnswer( - (_: InvocationOnMock) => { - val operation = mock(classOf[GetCatalogsOperation]) - when(operation.getHandle).thenReturn(mock(classOf[OperationHandle])) - operation - } - ) session.open(Map.empty[String, String].asJava) } @@ -60,14 +58,59 @@ class HiveSessionImplSuite extends SparkFunSuite { val operationHandle1 = session.getCatalogs val operationHandle2 = session.getCatalogs - when(operationManager.closeOperation(operationHandle1)) - .thenThrow(classOf[NullPointerException]) - when(operationManager.closeOperation(operationHandle2)) - .thenThrow(classOf[NullPointerException]) - session.close() - verify(operationManager).closeOperation(operationHandle1) - verify(operationManager).closeOperation(operationHandle2) + assert(operationManager.getCalledHandles.contains(operationHandle1)) + assert(operationManager.getCalledHandles.contains(operationHandle2)) + } +} + +class GetCatalogsOperationMock(parentSession: HiveSession) + extends GetCatalogsOperation(parentSession) { + + override def runInternal(): Unit = {} + + override def getHandle: OperationHandle = { + val uuid: UUID = UUID.randomUUID() + val tHandleIdentifier: THandleIdentifier = new THandleIdentifier() + tHandleIdentifier.setGuid(getByteBufferFromUUID(uuid)) + tHandleIdentifier.setSecret(getByteBufferFromUUID(uuid)) + val tOperationHandle: TOperationHandle = new TOperationHandle() + tOperationHandle.setOperationId(tHandleIdentifier) + tOperationHandle.setOperationType(TOperationType.GET_TYPE_INFO) + tOperationHandle.setHasResultSetIsSet(false) + new OperationHandle(tOperationHandle) } + + private def getByteBufferFromUUID(uuid: UUID): Array[Byte] = { + val bb: ByteBuffer = ByteBuffer.wrap(new Array[Byte](16)) + bb.putLong(uuid.getMostSignificantBits) + bb.putLong(uuid.getLeastSignificantBits) + bb.array + } +} + +class OperationManagerMock extends OperationManager { + private val calledHandles: mutable.Set[OperationHandle] = new mutable.HashSet[OperationHandle]() + + override def newGetCatalogsOperation(parentSession: HiveSession): GetCatalogsOperation = { + val operation = new GetCatalogsOperationMock(parentSession) + try { + val m = classOf[OperationManager].getDeclaredMethod("addOperation", classOf[Operation]) + m.setAccessible(true) + m.invoke(this, operation) + } catch { + case e@(_: NoSuchMethodException | _: IllegalAccessException | + _: InvocationTargetException) => + throw new RuntimeException(e) + } + operation + } + + override def closeOperation(opHandle: OperationHandle): Unit = { + calledHandles.add(opHandle) + throw new RuntimeException + } + + def getCalledHandles: mutable.Set[OperationHandle] = calledHandles } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala index 13df3fabc4919..4c2f29e0bf394 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala +++ 
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala @@ -17,10 +17,25 @@ package org.apache.spark.sql.hive.thriftserver +import java.util +import java.util.concurrent.Semaphore + +import scala.concurrent.duration._ + +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hive.service.cli.OperationState +import org.apache.hive.service.cli.session.{HiveSession, HiveSessionImpl} +import org.mockito.Mockito.{doReturn, mock, spy, when, RETURNS_DEEP_STUBS} +import org.mockito.invocation.InvocationOnMock + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2EventManager +import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType} -class SparkExecuteStatementOperationSuite extends SparkFunSuite { +class SparkExecuteStatementOperationSuite extends SparkFunSuite with SharedSparkSession { + test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") { val field1 = StructField("NULL", NullType) val field2 = StructField("(IF(true, NULL, NULL))", NullType) @@ -42,4 +57,68 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite { assert(columns.get(1).getType().getName == "INT") assert(columns.get(1).getComment() == "") } + + Seq( + (OperationState.CANCELED, (_: SparkExecuteStatementOperation).cancel()), + (OperationState.CLOSED, (_: SparkExecuteStatementOperation).close()) + ).foreach { case (finalState, transition) => + test("SPARK-32057 SparkExecuteStatementOperation should not transiently become ERROR " + + s"before being set to $finalState") { + val hiveSession = new HiveSessionImpl(ThriftserverShimUtils.testedProtocolVersions.head, + "username", "password", new HiveConf, "ip address") + hiveSession.open(new util.HashMap) + + HiveThriftServer2.eventManager = mock(classOf[HiveThriftServer2EventManager]) + + val spySqlContext = spy(sqlContext) + + // When cancel() is called on the operation, cleanup causes an exception to be thrown inside + // of execute(). This should not cause the state to become ERROR. The exception here will be + // triggered in our custom cleanup(). 
+ val signal = new Semaphore(0) + val dataFrame = mock(classOf[DataFrame], RETURNS_DEEP_STUBS) + when(dataFrame.collect()).thenAnswer((_: InvocationOnMock) => { + signal.acquire() + throw new RuntimeException("Operation was cancelled by test cleanup.") + }) + val statement = "stmt" + doReturn(dataFrame, Nil: _*).when(spySqlContext).sql(statement) + + val executeStatementOperation = new MySparkExecuteStatementOperation(spySqlContext, + hiveSession, statement, signal, finalState) + + val run = new Thread() { + override def run(): Unit = executeStatementOperation.runInternal() + } + assert(executeStatementOperation.getStatus.getState === OperationState.INITIALIZED) + run.start() + eventually(timeout(5.seconds)) { + assert(executeStatementOperation.getStatus.getState === OperationState.RUNNING) + } + transition(executeStatementOperation) + run.join() + assert(executeStatementOperation.getStatus.getState === finalState) + } + } + + private class MySparkExecuteStatementOperation( + sqlContext: SQLContext, + hiveSession: HiveSession, + statement: String, + signal: Semaphore, + finalState: OperationState) + extends SparkExecuteStatementOperation(sqlContext, hiveSession, statement, + new util.HashMap, false) { + + override def cleanup(): Unit = { + super.cleanup() + signal.release() + // At this point, operation should already be in finalState (set by either close() or + // cancel()). We want to check if it stays in finalState after the exception thrown by + // releasing the semaphore propagates. We hence need to sleep for a short while. + Thread.sleep(1000) + // State should not be ERROR + assert(getStatus.getState === finalState) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index 686dc1c9bad6b..30d911becdba7 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -67,7 +67,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ } /** List of test cases to ignore, in lower cases. */ - override def blackList: Set[String] = super.blackList ++ Set( + override def ignoreList: Set[String] = super.ignoreList ++ Set( // Missing UDF "postgreSQL/boolean.sql", "postgreSQL/case.sql", @@ -208,7 +208,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ } override def createScalaTestCase(testCase: TestCase): Unit = { - if (blackList.exists(t => + if (ignoreList.exists(t => testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) { // Create a test case to ignore this case. 
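Tying the SPARK-32057 test back to the earlier SparkExecuteStatementOperation/SparkOperation hunks, where cleanup() was moved after setState(OperationState.CANCELED) and super.close(): the point is that an exception raised while tearing down a running statement must not overwrite an already-terminal state with ERROR. A rough standalone sketch of that ordering argument, with invented names rather than the Hive OperationState API:

// Sketch only: an illustrative state machine, not the Hive/Spark OperationState API.
object OperationStateSketch {
  sealed trait State
  case object Running extends State
  case object Canceled extends State
  case object Error extends State

  final class Op {
    @volatile var state: State = Running

    // Cleanup interrupts the running statement, which surfaces as an exception.
    private def cleanup(): Unit = throw new RuntimeException("statement interrupted")

    def cancel(): Unit = {
      state = Canceled                 // terminal state is set first...
      try cleanup() catch {
        case _: RuntimeException =>
          if (state == Running) {      // ...so the failure handler sees a terminal state
            state = Error              // and no longer downgrades CANCELED to ERROR
          }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val op = new Op
    op.cancel()
    assert(op.state == Canceled)
  }
}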
ignore(testCase.name) { /* Do nothing */ } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 82af7dceb27f2..b7ea0630dd85f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -26,10 +26,12 @@ import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy +import org.apache.spark.tags.SlowHiveTest /** * Runs the test cases that are included in the hive distribution. */ +@SlowHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( @@ -83,7 +85,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { } /** A list of tests deemed out of scope currently and thus completely disregarded. */ - override def blackList: Seq[String] = Seq( + override def excludeList: Seq[String] = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. "hook_order", "hook_context_cs", @@ -514,7 +516,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // This test uses CREATE EXTERNAL TABLE without specifying LOCATION "alter2", - // [SPARK-16248][SQL] Whitelist the list of Hive fallback functions + // [SPARK-16248][SQL] Include the list of Hive fallback functions "udf_field", "udf_reflect2", "udf_xpath", @@ -602,7 +604,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_radians" ) - private def commonWhiteList = Seq( + private def commonIncludeList = Seq( "add_part_exist", "add_part_multiple", "add_partition_no_whitelist", @@ -1140,14 +1142,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { ) /** - * The set of tests that are believed to be working in catalyst. Tests not on whiteList or - * blacklist are implicitly marked as ignored. + * The set of tests that are believed to be working in catalyst. Tests not on includeList or + * excludeList are implicitly marked as ignored. */ - override def whiteList: Seq[String] = if (HiveUtils.isHive23) { - commonWhiteList ++ Seq( + override def includeList: Seq[String] = if (HiveUtils.isHive23) { + commonIncludeList ++ Seq( "decimal_1_1" ) } else { - commonWhiteList + commonIncludeList } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 2c0970c85449f..1b801ad69564c 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -763,7 +763,7 @@ class HiveWindowFunctionQueryFileSuite } } - override def blackList: Seq[String] = Seq( + override def excludeList: Seq[String] = Seq( // Partitioned table functions are not supported. 
"ptf*", // tests of windowing.q are in HiveWindowFunctionQueryBaseSuite @@ -791,12 +791,12 @@ class HiveWindowFunctionQueryFileSuite "windowing_adjust_rowcontainer_sz" ) - override def whiteList: Seq[String] = Seq( + override def includeList: Seq[String] = Seq( "windowing_udaf2" ) - // Only run those query tests in the realWhileList (do not try other ignored query files). + // Only run those query tests in the realIncludeList (do not try other ignored query files). override def testCases: Seq[(String, File)] = super.testCases.filter { - case (name, _) => realWhiteList.contains(name) + case (name, _) => realIncludeList.contains(name) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 2faf42028f3a2..f01a03996821a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -61,6 +61,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat import HiveExternalCatalog._ import CatalogTableType._ + // SPARK-32256: Make sure `VersionInfo` is initialized before touching the isolated classloader. + // This is to ensure Hive can get the Hadoop version when using the isolated classloader. + org.apache.hadoop.util.VersionInfo.getVersion() + /** * A Hive client used to interact with the metastore. */ @@ -829,8 +833,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat updateLocationInStorageProps(table, newPath = None).copy( locationUri = tableLocation.map(CatalogUtils.stringToURI(_))) } - val storageWithoutHiveGeneratedProperties = storageWithLocation.copy( - properties = storageWithLocation.properties.filterKeys(!HIVE_GENERATED_STORAGE_PROPERTIES(_))) + val storageWithoutHiveGeneratedProperties = storageWithLocation.copy(properties = + storageWithLocation.properties.filterKeys(!HIVE_GENERATED_STORAGE_PROPERTIES(_)).toMap) val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) val schemaFromTableProps = getSchemaFromTableProperties(table) @@ -844,7 +848,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table), tracksPartitionsInCatalog = partitionProvider == Some(TABLE_PARTITION_PROVIDER_CATALOG), - properties = table.properties.filterKeys(!HIVE_GENERATED_TABLE_PROPERTIES(_))) + properties = table.properties.filterKeys(!HIVE_GENERATED_TABLE_PROPERTIES(_)).toMap) } override def tableExists(db: String, table: String): Boolean = withClient { @@ -1121,7 +1125,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val colStats = new mutable.HashMap[String, CatalogColumnStat] val colStatsProps = properties.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)).map { case (k, v) => k.drop(STATISTICS_COL_STATS_PREFIX.length) -> v - } + }.toMap // Find all the column names by matching the KEY_VERSION properties for them. 
colStatsProps.keys.filter { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 16e9014340244..19aa5935a09d7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -694,7 +694,7 @@ private[hive] trait HiveInspectors { } data: Any => { if (data != null) { - InternalRow.fromSeq(unwrappers.map(_(data))) + InternalRow.fromSeq(unwrappers.map(_(data)).toSeq) } else { null } @@ -872,7 +872,7 @@ private[hive] trait HiveInspectors { StructType(s.getAllStructFieldRefs.asScala.map(f => types.StructField( f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - )) + ).toSeq) case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) case m: MapObjectInspector => MapType( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2981e391c0439..a89243c331c7b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -131,12 +131,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // Consider table and storage properties. For properties existing in both sides, storage // properties will supersede table properties. if (serde.contains("parquet")) { - val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + val options = relation.tableMeta.properties.filterKeys(isParquetProperty).toMap ++ relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + val options = relation.tableMeta.properties.filterKeys(isOrcProperty).toMap ++ relation.tableMeta.storage.properties if (SQLConf.get.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { convertToLogicalRelation( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 64726755237a6..78ec2b8e2047e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner} +import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.command.CommandCheck import org.apache.spark.sql.execution.datasources._ @@ -76,6 +77,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: + ResolveEncodersInScalaAgg +: new ResolveSessionCatalog( catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +: 
customResolutionRules @@ -109,7 +111,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val sparkSession: SparkSession = session override def extraPlanningStrategies: Seq[Strategy] = - super.extraPlanningStrategies ++ customPlanningStrategies ++ Seq(HiveTableScans, Scripts) + super.extraPlanningStrategies ++ customPlanningStrategies ++ + Seq(HiveTableScans, HiveScripts) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index b9c98f4ea15e9..dae68df08f32e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -28,10 +28,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, ScriptTransformation, Statistics} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogV2Util.assertNoNullTypeInSchema import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.hive.execution.{HiveScriptIOSchema, HiveScriptTransformationExec} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -225,6 +227,8 @@ case class RelationConversions( isConvertible(tableDesc) && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_CTAS) => // validation is required to be done here before relation conversion. DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) + // This is for CREATE TABLE .. 
STORED AS PARQUET/ORC AS SELECT null + assertNoNullTypeInSchema(query.schema) OptimizedCreateHiveTableAsSelectCommand( tableDesc, query, query.output.map(_.name), mode) } @@ -237,11 +241,11 @@ private[hive] trait HiveStrategies { val sparkSession: SparkSession - object Scripts extends Strategy { + object HiveScripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ScriptTransformation(input, script, output, child, ioschema) => val hiveIoSchema = HiveScriptIOSchema(ioschema) - ScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil + HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 04caf57efdc74..62ff2db2ecb3c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -408,7 +408,7 @@ private[spark] object HiveUtils extends Logging { logWarning(s"Hive jar path '$path' does not exist.") Nil } else { - files.filter(_.getName.toLowerCase(Locale.ROOT).endsWith(".jar")) + files.filter(_.getName.toLowerCase(Locale.ROOT).endsWith(".jar")).toSeq } case path => new File(path) :: Nil @@ -505,7 +505,7 @@ private[spark] object HiveUtils extends Logging { // partition columns are part of the schema val partCols = hiveTable.getPartCols.asScala.map(HiveClientImpl.fromHiveColumn) val dataCols = hiveTable.getCols.asScala.map(HiveClientImpl.fromHiveColumn) - table.copy(schema = StructType(dataCols ++ partCols)) + table.copy(schema = StructType((dataCols ++ partCols).toSeq)) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 4d18eb6289418..3e0d44160c8a1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -308,7 +308,7 @@ class HadoopTableReader( /** * Creates a HadoopRDD based on the broadcasted HiveConf and other job properties that will be - * applied locally on each slave. + * applied locally on each executor. */ private def createOldHadoopRDD(tableDesc: TableDesc, path: String): RDD[Writable] = { val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _ @@ -330,7 +330,7 @@ class HadoopTableReader( /** * Creates a NewHadoopRDD based on the broadcasted HiveConf and other job properties that will be - * applied locally on each slave. + * applied locally on each executor. 
*/ private def createNewHadoopRDD(tableDesc: TableDesc, path: String): RDD[Writable] = { val newJobConf = new JobConf(hadoopConf) @@ -486,13 +486,19 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { var i = 0 val length = fieldRefs.length while (i < length) { - val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) - if (fieldValue == null) { - mutableRow.setNullAt(fieldOrdinals(i)) - } else { - unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + try { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } catch { + case ex: Throwable => + logError(s"Exception thrown in field <${fieldRefs(i).getFieldName}>") + throw ex } - i += 1 } mutableRow: InternalRow diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 6ad5e9d3c9080..3f70387a3b058 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -390,7 +390,7 @@ private[hive] class HiveClientImpl( } override def listDatabases(pattern: String): Seq[String] = withHiveState { - client.getDatabasesByPattern(pattern).asScala + client.getDatabasesByPattern(pattern).asScala.toSeq } private def getRawTableOption(dbName: String, tableName: String): Option[HiveTable] = { @@ -400,7 +400,7 @@ private[hive] class HiveClientImpl( private def getRawTablesByName(dbName: String, tableNames: Seq[String]): Seq[HiveTable] = { try { msClient.getTableObjectsByName(dbName, tableNames.asJava).asScala - .map(extraFixesForNonView).map(new HiveTable(_)) + .map(extraFixesForNonView).map(new HiveTable(_)).toSeq } catch { case ex: Exception => throw new HiveException(s"Unable to fetch tables of db $dbName", ex); @@ -434,7 +434,7 @@ private[hive] class HiveClientImpl( throw new SparkException( s"${ex.getMessage}, db: ${h.getDbName}, table: ${h.getTableName}", ex) } - val schema = StructType(cols ++ partCols) + val schema = StructType((cols ++ partCols).toSeq) val bucketSpec = if (h.getNumBuckets > 0) { val sortColumnOrders = h.getSortCols.asScala @@ -450,7 +450,7 @@ private[hive] class HiveClientImpl( } else { Seq.empty } - Option(BucketSpec(h.getNumBuckets, h.getBucketCols.asScala, sortColumnNames)) + Option(BucketSpec(h.getNumBuckets, h.getBucketCols.asScala.toSeq, sortColumnNames.toSeq)) } else { None } @@ -502,7 +502,7 @@ private[hive] class HiveClientImpl( throw new AnalysisException(s"Hive $tableTypeStr is not supported.") }, schema = schema, - partitionColumnNames = partCols.map(_.name), + partitionColumnNames = partCols.map(_.name).toSeq, // If the table is written by Spark, we will put bucketing information in table properties, // and will always overwrite the bucket spec in hive metastore by the bucketing information // in table properties. This means, if we have bucket spec in both hive metastore and @@ -539,7 +539,7 @@ private[hive] class HiveClientImpl( // that created by older versions of Spark. 
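The HadoopTableReader hunk above now wraps each struct-field read in a try/catch so the offending field name is logged before the original exception is rethrown. The same idea in isolation, as a hypothetical helper (not Spark code):

// Sketch only: a generic "annotate the failure with the field name, then rethrow" helper.
object FieldReadSketch {
  def readFields[A](fields: Seq[(String, () => A)]): Seq[A] =
    fields.map { case (name, read) =>
      try {
        read()
      } catch {
        case ex: Throwable =>
          // Mirrors the intent of the added logError call: identify the field, keep the exception.
          Console.err.println(s"Exception thrown in field <$name>")
          throw ex
      }
    }

  def main(args: Array[String]): Unit = {
    val values = readFields(Seq("a" -> (() => 1), "b" -> (() => 2)))
    println(values)
  }
}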
viewOriginalText = Option(h.getViewOriginalText), viewText = Option(h.getViewExpandedText), - unsupportedFeatures = unsupportedFeatures, + unsupportedFeatures = unsupportedFeatures.toSeq, ignoredProperties = ignoredProperties.toMap) } @@ -638,7 +638,7 @@ private[hive] class HiveClientImpl( shim.dropPartition(client, db, table, partition, !retainData, purge) } catch { case e: Exception => - val remainingParts = matchingParts.toBuffer -- droppedParts + val remainingParts = matchingParts.toBuffer --= droppedParts logError( s""" |====================== @@ -708,7 +708,7 @@ private[hive] class HiveClientImpl( assert(s.values.forall(_.nonEmpty), s"partition spec '$s' is invalid") client.getPartitionNames(table.database, table.identifier.table, s.asJava, -1) } - hivePartitionNames.asScala.sorted + hivePartitionNames.asScala.sorted.toSeq } override def getPartitionOption( @@ -735,7 +735,7 @@ private[hive] class HiveClientImpl( } val parts = client.getPartitions(hiveTable, partSpec.asJava).asScala.map(fromHivePartition) HiveCatalogMetrics.incrementFetchedPartitions(parts.length) - parts + parts.toSeq } override def getPartitionsByFilter( @@ -748,11 +748,11 @@ private[hive] class HiveClientImpl( } override def listTables(dbName: String): Seq[String] = withHiveState { - client.getAllTables(dbName).asScala + client.getAllTables(dbName).asScala.toSeq } override def listTables(dbName: String, pattern: String): Seq[String] = withHiveState { - client.getTablesByPattern(dbName, pattern).asScala + client.getTablesByPattern(dbName, pattern).asScala.toSeq } override def listTablesByType( @@ -766,7 +766,7 @@ private[hive] class HiveClientImpl( case _: UnsupportedOperationException => // Fallback to filter logic if getTablesByType not supported. val tableNames = client.getTablesByPattern(dbName, pattern).asScala - val tables = getTablesByName(dbName, tableNames).filter(_.tableType == tableType) + val tables = getTablesByName(dbName, tableNames.toSeq).filter(_.tableType == tableType) tables.map(_.identifier.table) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 8df43b785759e..8ff7a1abd2d6b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -363,7 +363,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { override def getDriverResults(driver: Driver): Seq[String] = { val res = new JArrayList[String]() getDriverResultsMethod.invoke(driver, res) - res.asScala + res.asScala.toSeq } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { @@ -600,7 +600,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } FunctionResource(FunctionResourceType.fromString(resourceType), uri.getUri()) } - CatalogFunction(name, hf.getClassName, resources) + CatalogFunction(name, hf.getClassName, resources.toSeq) } override def getFunctionOption(hive: Hive, db: String, name: String): Option[CatalogFunction] = { @@ -623,7 +623,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } override def listFunctions(hive: Hive, db: String, pattern: String): Seq[String] = { - hive.getFunctions(db, pattern).asScala + hive.getFunctions(db, pattern).asScala.toSeq } /** @@ -843,7 +843,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { case s: String => s case a: Array[Object] => a(0).asInstanceOf[String] } - } + }.toSeq } override def getDatabaseOwnerName(db: 
Database): String = { @@ -1252,7 +1252,7 @@ private[client] class Shim_v2_3 extends Shim_v2_1 { pattern: String, tableType: TableType): Seq[String] = { getTablesByTypeMethod.invoke(hive, dbName, pattern, tableType) - .asInstanceOf[JList[String]].asScala + .asInstanceOf[JList[String]].asScala.toSeq } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 802ddafdbee4d..7b51618772edc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -87,7 +87,7 @@ class HiveOptions(@transient private val parameters: CaseInsensitiveMap[String]) def serdeProperties: Map[String, String] = parameters.filterKeys { k => !lowerCasedOptionNames.contains(k.toLowerCase(Locale.ROOT)) - }.map { case (k, v) => delimiterOptions.getOrElse(k, k) -> v } + }.map { case (k, v) => delimiterOptions.getOrElse(k, k) -> v }.toMap } object HiveOptions { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 4dccacef337e9..41820b0135f4a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -156,7 +156,7 @@ case class HiveTableScanExec( // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. - val row = InternalRow.fromSeq(castedValues) + val row = InternalRow.fromSeq(castedValues.toSeq) shouldKeep.eval(row).asInstanceOf[Boolean] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala index c7183fd7385a6..96fe646d39fde 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.hive.execution import java.io._ import java.nio.charset.StandardCharsets import java.util.Properties -import java.util.concurrent.TimeUnit import javax.annotation.Nullable import scala.collection.JavaConverters._ @@ -33,19 +32,15 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema -import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveInspectors import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} /** * Transforms the input by forking and running the specified script. 
@@ -54,301 +49,211 @@ import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfig * @param script the command that should be executed. * @param output the attributes that are produced by the script. */ -case class ScriptTransformationExec( +case class HiveScriptTransformationExec( input: Seq[Expression], script: String, output: Seq[Attribute], child: SparkPlan, ioschema: HiveScriptIOSchema) - extends UnaryExecNode { - - override def producedAttributes: AttributeSet = outputSet -- inputSet - - override def outputPartitioning: Partitioning = child.outputPartitioning - - protected override def doExecute(): RDD[InternalRow] = { - def processIterator(inputIterator: Iterator[InternalRow], hadoopConf: Configuration) - : Iterator[InternalRow] = { - val cmd = List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd.asJava) - - val proc = builder.start() - val inputStream = proc.getInputStream - val outputStream = proc.getOutputStream - val errorStream = proc.getErrorStream - - // In order to avoid deadlocks, we need to consume the error output of the child process. - // To avoid issues caused by large error output, we use a circular buffer to limit the amount - // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang - // that motivates this. - val stderrBuffer = new CircularBuffer(2048) - new RedirectThread( - errorStream, - stderrBuffer, - "Thread-ScriptTransformation-STDERR-Consumer").start() - - val outputProjection = new InterpretedProjection(input, child.output) - - // This nullability is a performance optimization in order to avoid an Option.foreach() call - // inside of a loop - @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) - - // This new thread will consume the ScriptTransformation's input rows and write them to the - // external process. That process's output will be read by this current thread. 
- val writerThread = new ScriptTransformationWriterThread( - inputIterator.map(outputProjection), - input.map(_.dataType), - inputSerde, - inputSoi, - ioschema, - outputStream, - proc, - stderrBuffer, - TaskContext.get(), - hadoopConf - ) - - // This nullability is a performance optimization in order to avoid an Option.foreach() call - // inside of a loop - @Nullable val (outputSerde, outputSoi) = { - ioschema.initOutputSerDe(output).getOrElse((null, null)) - } - - val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) - val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { - var curLine: String = null - val scriptOutputStream = new DataInputStream(inputStream) - - @Nullable val scriptOutputReader = - ioschema.recordReader(scriptOutputStream, hadoopConf).orNull - - var scriptOutputWritable: Writable = null - val reusedWritableObject: Writable = if (null != outputSerde) { - outputSerde.getSerializedClass().getConstructor().newInstance() - } else { - null - } - val mutableRow = new SpecificInternalRow(output.map(_.dataType)) + extends BaseScriptTransformationExec { + + override def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration): Iterator[InternalRow] = { + val cmd = List("/bin/bash", "-c", script) + val builder = new ProcessBuilder(cmd.asJava) + + val proc = builder.start() + val inputStream = proc.getInputStream + val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream + + // In order to avoid deadlocks, we need to consume the error output of the child process. + // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. + val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + "Thread-ScriptTransformation-STDERR-Consumer").start() + + val outputProjection = new InterpretedProjection(input, child.output) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. 
+ val writerThread = new HiveScriptTransformationWriterThread( + inputIterator.map(outputProjection), + input.map(_.dataType), + inputSerde, + inputSoi, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get(), + hadoopConf + ) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (outputSerde, outputSoi) = { + ioschema.initOutputSerDe(output).getOrElse((null, null)) + } - @transient - lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor) + val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + var curLine: String = null + val scriptOutputStream = new DataInputStream(inputStream) - private def checkFailureAndPropagate(cause: Throwable = null): Unit = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } + @Nullable val scriptOutputReader = + ioschema.recordReader(scriptOutputStream, hadoopConf).orNull - // There can be a lag between reader read EOF and the process termination. - // If the script fails to startup, this kind of error may be missed. - // So explicitly waiting for the process termination. - val timeout = conf.getConf(SQLConf.SCRIPT_TRANSFORMATION_EXIT_TIMEOUT) - val exitRes = proc.waitFor(timeout, TimeUnit.SECONDS) - if (!exitRes) { - log.warn(s"Transformation script process exits timeout in $timeout seconds") - } + var scriptOutputWritable: Writable = null + val reusedWritableObject: Writable = if (null != outputSerde) { + outputSerde.getSerializedClass().getConstructor().newInstance() + } else { + null + } + val mutableRow = new SpecificInternalRow(output.map(_.dataType)) - if (!proc.isAlive) { - val exitCode = proc.exitValue() - if (exitCode != 0) { - logError(stderrBuffer.toString) // log the stderr circular buffer - throw new SparkException(s"Subprocess exited with status $exitCode. " + - s"Error: ${stderrBuffer.toString}", cause) - } - } - } + @transient + lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor) - override def hasNext: Boolean = { - try { - if (outputSerde == null) { + override def hasNext: Boolean = { + try { + if (outputSerde == null) { + if (curLine == null) { + curLine = reader.readLine() if (curLine == null) { - curLine = reader.readLine() - if (curLine == null) { - checkFailureAndPropagate() - return false - } + checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) + return false } - } else if (scriptOutputWritable == null) { - scriptOutputWritable = reusedWritableObject + } + } else if (scriptOutputWritable == null) { + scriptOutputWritable = reusedWritableObject - if (scriptOutputReader != null) { - if (scriptOutputReader.next(scriptOutputWritable) <= 0) { - checkFailureAndPropagate() + if (scriptOutputReader != null) { + if (scriptOutputReader.next(scriptOutputWritable) <= 0) { + checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) + return false + } + } else { + try { + scriptOutputWritable.readFields(scriptOutputStream) + } catch { + case _: EOFException => + // This means that the stdout of `proc` (ie. TRANSFORM process) has exhausted. + // Ideally the proc should *not* be alive at this point but + // there can be a lag between EOF being written out and the process + // being terminated. So explicitly waiting for the process to be done. 
+ checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) return false - } - } else { - try { - scriptOutputWritable.readFields(scriptOutputStream) - } catch { - case _: EOFException => - // This means that the stdout of `proc` (ie. TRANSFORM process) has exhausted. - // Ideally the proc should *not* be alive at this point but - // there can be a lag between EOF being written out and the process - // being terminated. So explicitly waiting for the process to be done. - checkFailureAndPropagate() - return false - } } } + } - true - } catch { - case NonFatal(e) => - // If this exception is due to abrupt / unclean termination of `proc`, - // then detect it and propagate a better exception message for end users - checkFailureAndPropagate(e) + true + } catch { + case NonFatal(e) => + // If this exception is due to abrupt / unclean termination of `proc`, + // then detect it and propagate a better exception message for end users + checkFailureAndPropagate(writerThread, e, proc, stderrBuffer) - throw e - } + throw e } + } - override def next(): InternalRow = { - if (!hasNext) { - throw new NoSuchElementException + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException + } + if (outputSerde == null) { + val prevLine = curLine + curLine = reader.readLine() + if (!ioschema.schemaLess) { + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + .map(CatalystTypeConverters.convertToCatalyst)) + } else { + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + .map(CatalystTypeConverters.convertToCatalyst)) } - if (outputSerde == null) { - val prevLine = curLine - curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .map(CatalystTypeConverters.convertToCatalyst)) + } else { + val raw = outputSerde.deserialize(scriptOutputWritable) + scriptOutputWritable = null + val dataList = outputSoi.getStructFieldsDataAsList(raw) + var i = 0 + while (i < dataList.size()) { + if (dataList.get(i) == null) { + mutableRow.setNullAt(i) } else { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .map(CatalystTypeConverters.convertToCatalyst)) - } - } else { - val raw = outputSerde.deserialize(scriptOutputWritable) - scriptOutputWritable = null - val dataList = outputSoi.getStructFieldsDataAsList(raw) - var i = 0 - while (i < dataList.size()) { - if (dataList.get(i) == null) { - mutableRow.setNullAt(i) - } else { - unwrappers(i)(dataList.get(i), mutableRow, i) - } - i += 1 + unwrappers(i)(dataList.get(i), mutableRow, i) } - mutableRow + i += 1 } + mutableRow } } - - writerThread.start() - - outputIterator } - val broadcastedHadoopConf = - new SerializableConfiguration(sqlContext.sessionState.newHadoopConf()) + writerThread.start() - child.execute().mapPartitions { iter => - if (iter.hasNext) { - val proj = UnsafeProjection.create(schema) - processIterator(iter, broadcastedHadoopConf.value).map(proj) - } else { - // If the input iterator has no rows then do not launch the external script. 
- Iterator.empty - } - } + outputIterator } } -private class ScriptTransformationWriterThread( +private class HiveScriptTransformationWriterThread( iter: Iterator[InternalRow], inputSchema: Seq[DataType], @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: StructObjectInspector, - ioschema: HiveScriptIOSchema, + ioSchema: HiveScriptIOSchema, outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, taskContext: TaskContext, - conf: Configuration - ) extends Thread("Thread-ScriptTransformation-Feed") with HiveInspectors with Logging { - - setDaemon(true) - - @volatile private var _exception: Throwable = null - - /** Contains the exception thrown while writing the parent iterator to the external process. */ - def exception: Option[Throwable] = Option(_exception) - - override def run(): Unit = Utils.logUncaughtExceptions { - TaskContext.setTaskContext(taskContext) - + conf: Configuration) + extends BaseScriptTransformationWriterThread( + iter, + inputSchema, + ioSchema, + outputStream, + proc, + stderrBuffer, + taskContext, + conf) with HiveInspectors { + + override def processRows(): Unit = { val dataOutputStream = new DataOutputStream(outputStream) - @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull - - // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so - // let's use a variable to record whether the `finally` block was hit due to an exception - var threwException: Boolean = true - val len = inputSchema.length - try { - if (inputSerde == null) { - iter.foreach { row => - val data = if (len == 0) { - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") - } else { - val sb = new StringBuilder - sb.append(row.get(0, inputSchema(0))) - var i = 1 - while (i < len) { - sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - sb.append(row.get(i, inputSchema(i))) - i += 1 - } - sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) - sb.toString() - } - outputStream.write(data.getBytes(StandardCharsets.UTF_8)) + @Nullable val scriptInputWriter = ioSchema.recordWriter(dataOutputStream, conf).orNull + + if (inputSerde == null) { + processRowsWithoutSerde() + } else { + // Convert Spark InternalRows to hive data via `HiveInspectors.wrapperFor`. + val hiveData = new Array[Any](inputSchema.length) + val fieldOIs = inputSoi.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray + val wrappers = fieldOIs.zip(inputSchema).map { case (f, dt) => wrapperFor(f, dt) } + + iter.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + hiveData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, inputSchema(i))) + i += 1 } - } else { - // Convert Spark InternalRows to hive data via `HiveInspectors.wrapperFor`. 
- val hiveData = new Array[Any](inputSchema.length) - val fieldOIs = inputSoi.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray - val wrappers = fieldOIs.zip(inputSchema).map { case (f, dt) => wrapperFor(f, dt) } - iter.foreach { row => - var i = 0 - while (i < fieldOIs.length) { - hiveData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, inputSchema(i))) - i += 1 - } - - val writable = inputSerde.serialize(hiveData, inputSoi) - if (scriptInputWriter != null) { - scriptInputWriter.write(writable) - } else { - prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) - } - } - } - threwException = false - } catch { - // SPARK-25158 Exception should not be thrown again, otherwise it will be captured by - // SparkUncaughtExceptionHandler, then Executor will exit because of this Uncaught Exception, - // so pass the exception to `ScriptTransformationExec` is enough. - case t: Throwable => - // An error occurred while writing input, so kill the child process. According to the - // Javadoc this call will not throw an exception: - _exception = t - proc.destroy() - logError("Thread-ScriptTransformation-Feed exit cause by: ", t) - } finally { - try { - Utils.tryLogNonFatalError(outputStream.close()) - if (proc.waitFor() != 0) { - logError(stderrBuffer.toString) // log the stderr circular buffer + val writable = inputSerde.serialize(hiveData, inputSoi) + if (scriptInputWriter != null) { + scriptInputWriter.write(writable) + } else { + prepareWritable(writable, ioSchema.outputSerdeProps).write(dataOutputStream) } - } catch { - case NonFatal(exceptionFromFinallyBlock) => - if (!threwException) { - throw exceptionFromFinallyBlock - } else { - log.error("Exception in finally block", exceptionFromFinallyBlock) - } } } } @@ -382,16 +287,7 @@ case class HiveScriptIOSchema ( recordReaderClass: Option[String], recordWriterClass: Option[String], schemaLess: Boolean) - extends HiveInspectors { - - private val defaultFormat = Map( - ("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n") - ) - - val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - + extends BaseScriptTransformIOSchema with HiveInspectors { def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { inputSerdeClass.map { serdeClass => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 05d608a2016a5..8ad5cb70d248b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -224,7 +224,7 @@ private[hive] case class HiveGenericUDTF( override lazy val elementSchema = StructType(outputInspector.getAllStructFieldRefs.asScala.map { field => StructField(field.getFieldName, inspectorToDataType(field.getFieldObjectInspector), nullable = true) - }) + }.toSeq) @transient private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray @@ -257,7 +257,7 @@ private[hive] case class HiveGenericUDTF( def collectRows(): Seq[InternalRow] = { val toCollect = collected collected = new ArrayBuffer[InternalRow] - toCollect + toCollect.toSeq } } diff --git a/sql/hive/src/test/resources/data/scripts/cat.py b/sql/hive/src/test/resources/data/scripts/cat.py index aea0362f899fa..420d9f832a184 100644 --- a/sql/hive/src/test/resources/data/scripts/cat.py +++ 
b/sql/hive/src/test/resources/data/scripts/cat.py @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. # -from __future__ import print_function import sys import os diff --git a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py index 5b360208d36f6..f724fdc85b177 100644 --- a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py +++ b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py @@ -18,12 +18,9 @@ # import sys -if sys.version_info[0] >= 3: - xrange = range - -for i in xrange(50): - for j in xrange(5): - for k in xrange(20022): +for i in range(50): + for j in range(5): + for k in range(20022): print(20000 * i + k) for line in sys.stdin: diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_no_includelist.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_no_includelist.q new file mode 100644 index 0000000000000..17677122a1bca --- /dev/null +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_no_includelist.q @@ -0,0 +1,7 @@ +SET hive.metastore.partition.name.whitelist.pattern=; +-- Test with no partition name include-list pattern + +CREATE TABLE part_noincludelist_test (key STRING, value STRING) PARTITIONED BY (ds STRING); +SHOW PARTITIONS part_noincludelist_test; + +ALTER TABLE part_noincludelist_test ADD PARTITION (ds='1,2,3,4'); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_no_whitelist.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_no_whitelist.q deleted file mode 100644 index f51c53c2ff627..0000000000000 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_no_whitelist.q +++ /dev/null @@ -1,7 +0,0 @@ -SET hive.metastore.partition.name.whitelist.pattern=; --- Test with no partition name whitelist pattern - -CREATE TABLE part_nowhitelist_test (key STRING, value STRING) PARTITIONED BY (ds STRING); -SHOW PARTITIONS part_nowhitelist_test; - -ALTER TABLE part_nowhitelist_test ADD PARTITION (ds='1,2,3,4'); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_with_includelist.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_with_includelist.q new file mode 100644 index 0000000000000..7e7f30dc37305 --- /dev/null +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_with_includelist.q @@ -0,0 +1,7 @@ +SET hive.metastore.partition.name.whitelist.pattern=[A-Za-z]*; +-- This pattern matches only letters. + +CREATE TABLE part_includelist_test (key STRING, value STRING) PARTITIONED BY (ds STRING); +SHOW PARTITIONS part_includelist_test; + +ALTER TABLE part_includelist_test ADD PARTITION (ds='Part'); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_with_whitelist.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_with_whitelist.q deleted file mode 100644 index 009c7610ef917..0000000000000 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/add_partition_with_whitelist.q +++ /dev/null @@ -1,9 +0,0 @@ -SET hive.metastore.partition.name.whitelist.pattern=[A-Za-z]*; --- This pattern matches only letters. 
- -CREATE TABLE part_whitelist_test (key STRING, value STRING) PARTITIONED BY (ds STRING); -SHOW PARTITIONS part_whitelist_test; - -ALTER TABLE part_whitelist_test ADD PARTITION (ds='Part'); - - diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/alter_partition_with_includelist.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/alter_partition_with_includelist.q new file mode 100644 index 0000000000000..fcef12cbaac4e --- /dev/null +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/alter_partition_with_includelist.q @@ -0,0 +1,9 @@ +SET hive.metastore.partition.name.whitelist.pattern=[A-Za-z]*; +-- This pattern matches only letters. + +CREATE TABLE part_includelist_test (key STRING, value STRING) PARTITIONED BY (ds STRING); +SHOW PARTITIONS part_includelist_test; + +ALTER TABLE part_includelist_test ADD PARTITION (ds='Part'); + +ALTER TABLE part_includelist_test PARTITION (ds='Part') rename to partition (ds='Apart'); diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/alter_partition_with_whitelist.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/alter_partition_with_whitelist.q deleted file mode 100644 index 301362a881456..0000000000000 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/alter_partition_with_whitelist.q +++ /dev/null @@ -1,9 +0,0 @@ -SET hive.metastore.partition.name.whitelist.pattern=[A-Za-z]*; --- This pattern matches only letters. - -CREATE TABLE part_whitelist_test (key STRING, value STRING) PARTITIONED BY (ds STRING); -SHOW PARTITIONS part_whitelist_test; - -ALTER TABLE part_whitelist_test ADD PARTITION (ds='Part'); - -ALTER TABLE part_whitelist_test PARTITION (ds='Part') rename to partition (ds='Apart'); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 8be3d26bfc93a..aa96fa035c4f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.tags.ExtendedHiveTest +import org.apache.spark.tags.{ExtendedHiveTest, SlowHiveTest} import org.apache.spark.util.Utils /** @@ -46,6 +46,7 @@ import org.apache.spark.util.Utils * expected version under this local directory, e.g. `/tmp/spark-test/spark-2.0.3`, we will skip the * downloading for this spark version. 
*/ +@SlowHiveTest @ExtendedHiveTest class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val isTestAtLeastJava9 = SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala index cfcf70c0e79f0..446923ad23201 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala @@ -279,7 +279,7 @@ class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSinglet table.copy( createTime = 0L, lastAccessTime = 0L, - properties = table.properties.filterKeys(!nondeterministicProps.contains(_)), + properties = table.properties.filterKeys(!nondeterministicProps.contains(_)).toMap, stats = None, ignoredProperties = Map.empty, storage = table.storage.copy(properties = Map.empty), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 8b97489e2d818..3a7e92ee1c00b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -38,12 +38,13 @@ import org.apache.spark.sql.hive.test.{HiveTestJars, TestHiveContext} import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH import org.apache.spark.sql.types.{DecimalType, StructType} -import org.apache.spark.tags.ExtendedHiveTest +import org.apache.spark.tags.{ExtendedHiveTest, SlowHiveTest} import org.apache.spark.util.{ResetSystemProperties, Utils} /** * This suite tests spark-submit with applications using HiveContext. */ +@SlowHiveTest @ExtendedHiveTest class HiveSparkSubmitSuite extends SparkSubmitTestUtils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index c1eab63ec073f..be6d023302293 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -911,7 +911,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto */ private def getStatsProperties(tableName: String): Map[String, String] = { val hTable = hiveClient.getTable(spark.sessionState.catalog.getCurrentDatabase, tableName) - hTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + hTable.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)).toMap } test("change stats after insert command for hive table") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HadoopVersionInfoSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HadoopVersionInfoSuite.scala new file mode 100644 index 0000000000000..65492abf38cc0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HadoopVersionInfoSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.File +import java.net.URLClassLoader + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} +import org.apache.spark.util.Utils + +/** + * This test suite requires a clean JVM because it's testing the initialization of static codes in + * `org.apache.hadoop.util.VersionInfo`. + */ +class HadoopVersionInfoSuite extends SparkFunSuite { + override protected val enableAutoThreadAudit = false + + test("SPARK-32256: Hadoop VersionInfo should be preloaded") { + val ivyPath = + Utils.createTempDir(namePrefix = s"${classOf[HadoopVersionInfoSuite].getSimpleName}-ivy") + try { + val hadoopConf = new Configuration() + hadoopConf.set("test", "success") + hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + hadoopConf.set("hive.metastore.schema.verification", "false") + + // Download jars for Hive 2.0 + val client = IsolatedClientLoader.forVersion( + hiveMetastoreVersion = "2.0", + hadoopVersion = "2.7.4", + sparkConf = new SparkConf(), + hadoopConf = hadoopConf, + config = HiveClientBuilder.buildConf(Map.empty), + ivyPath = Some(ivyPath.getCanonicalPath), + sharesHadoopClasses = true) + val jars = client.classLoader.getParent.asInstanceOf[URLClassLoader].getURLs + .map(u => new File(u.toURI)) + // Drop all Hadoop jars to use the existing Hadoop jars on the classpath + .filter(!_.getName.startsWith("org.apache.hadoop_hadoop-")) + + val sparkConf = new SparkConf() + sparkConf.set(HiveUtils.HIVE_METASTORE_VERSION, "2.0") + sparkConf.set( + HiveUtils.HIVE_METASTORE_JARS, + jars.map(_.getCanonicalPath).mkString(File.pathSeparator)) + HiveClientBuilder.buildConf(Map.empty).foreach { case (k, v) => + hadoopConf.set(k, v) + } + new HiveExternalCatalog(sparkConf, hadoopConf).client.getState + } finally { + Utils.deleteRecursively(ivyPath) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala index ab73f668c6ca6..2ad3afcb214b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -33,7 +33,7 @@ private[client] object HiveClientBuilder { Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) } - private def buildConf(extraConf: Map[String, String]) = { + private[client] def buildConf(extraConf: Map[String, String]): Map[String, String] = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() metastorePath.delete() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 8642a5ff16812..c5c92ddad9014 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.sql.hive.test.TestHiveVersion import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.StructType -import org.apache.spark.tags.ExtendedHiveTest +import org.apache.spark.tags.{ExtendedHiveTest, SlowHiveTest} import org.apache.spark.util.{MutableURLClassLoader, Utils} /** @@ -51,6 +51,7 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} * is not fully tested. */ // TODO: Refactor this to `HiveClientSuite` and make it a subclass of `HiveVersionSuite` +@SlowHiveTest @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index fac981267f4d7..87771eed17b1b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.tags.SlowHiveTest import org.apache.spark.unsafe.UnsafeAlignedOffset @@ -780,7 +781,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(math.abs(corr6 + 1.0) < 1e-12) // Test for udaf_corr in HiveCompatibilitySuite - // udaf_corr has been blacklisted due to numerical errors + // udaf_corr has been excluded due to numerical errors // We test it here: // SELECT corr(b, c) FROM covar_tab WHERE a < 1; => NULL // SELECT corr(b, c) FROM covar_tab WHERE a < 3; => NULL @@ -1054,6 +1055,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te class HashAggregationQuerySuite extends AggregationQuerySuite +@SlowHiveTest class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index e8cf4ad5d9f28..fbd1fc1ea98df 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} @@ -44,9 +45,11 @@ import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import 
org.apache.spark.tags.SlowHiveTest import org.apache.spark.util.Utils // TODO(gatorsmile): combine HiveCatalogedDDLSuite and HiveDDLSuite +@SlowHiveTest class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeAndAfterEach { override def afterEach(): Unit = { try { @@ -125,7 +128,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA createTime = 0L, lastAccessTime = 0L, owner = "", - properties = table.properties.filterKeys(!nondeterministicProps.contains(_)), + properties = table.properties.filterKeys(!nondeterministicProps.contains(_)).toMap, // View texts are checked separately viewText = None ) @@ -404,6 +407,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA } } +@SlowHiveTest class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { import testImplicits._ @@ -2309,6 +2313,126 @@ class HiveDDLSuite } } + test("SPARK-20680: do not support for null column datatype") { + withTable("t") { + withView("tabNullType") { + hiveClient.runSqlHive("CREATE TABLE t (t1 int)") + hiveClient.runSqlHive("INSERT INTO t VALUES (3)") + hiveClient.runSqlHive("CREATE VIEW tabNullType AS SELECT NULL AS col FROM t") + checkAnswer(spark.table("tabNullType"), Row(null)) + // No exception shows + val desc = spark.sql("DESC tabNullType").collect().toSeq + assert(desc.contains(Row("col", NullType.simpleString, null))) + } + } + + // Forbid CTAS with null type + withTable("t1", "t2", "t3") { + val e1 = intercept[AnalysisException] { + spark.sql("CREATE TABLE t1 USING PARQUET AS SELECT null as null_col") + }.getMessage + assert(e1.contains("Cannot create tables with null type")) + + val e2 = intercept[AnalysisException] { + spark.sql("CREATE TABLE t2 AS SELECT null as null_col") + }.getMessage + assert(e2.contains("Cannot create tables with null type")) + + val e3 = intercept[AnalysisException] { + spark.sql("CREATE TABLE t3 STORED AS PARQUET AS SELECT null as null_col") + }.getMessage + assert(e3.contains("Cannot create tables with null type")) + } + + // Forbid Replace table AS SELECT with null type + withTable("t") { + val v2Source = classOf[FakeV2Provider].getName + val e = intercept[AnalysisException] { + spark.sql(s"CREATE OR REPLACE TABLE t USING $v2Source AS SELECT null as null_col") + }.getMessage + assert(e.contains("Cannot create tables with null type")) + } + + // Forbid creating table with VOID type in Spark + withTable("t1", "t2", "t3", "t4") { + val e1 = intercept[AnalysisException] { + spark.sql(s"CREATE TABLE t1 (v VOID) USING PARQUET") + }.getMessage + assert(e1.contains("Cannot create tables with null type")) + val e2 = intercept[AnalysisException] { + spark.sql(s"CREATE TABLE t2 (v VOID) USING hive") + }.getMessage + assert(e2.contains("Cannot create tables with null type")) + val e3 = intercept[AnalysisException] { + spark.sql(s"CREATE TABLE t3 (v VOID)") + }.getMessage + assert(e3.contains("Cannot create tables with null type")) + val e4 = intercept[AnalysisException] { + spark.sql(s"CREATE TABLE t4 (v VOID) STORED AS PARQUET") + }.getMessage + assert(e4.contains("Cannot create tables with null type")) + } + + // Forbid Replace table with VOID type + withTable("t") { + val v2Source = classOf[FakeV2Provider].getName + val e = intercept[AnalysisException] { + spark.sql(s"CREATE OR REPLACE TABLE t (v VOID) USING $v2Source") + }.getMessage + assert(e.contains("Cannot create tables with null type")) + } + + // Make sure spark.catalog.createTable with null type will 
fail + val schema1 = new StructType().add("c", NullType) + assertHiveTableNullType(schema1) + assertDSTableNullType(schema1) + + val schema2 = new StructType() + .add("c", StructType(Seq(StructField.apply("c1", NullType)))) + assertHiveTableNullType(schema2) + assertDSTableNullType(schema2) + + val schema3 = new StructType().add("c", ArrayType(NullType)) + assertHiveTableNullType(schema3) + assertDSTableNullType(schema3) + + val schema4 = new StructType() + .add("c", MapType(StringType, NullType)) + assertHiveTableNullType(schema4) + assertDSTableNullType(schema4) + + val schema5 = new StructType() + .add("c", MapType(NullType, StringType)) + assertHiveTableNullType(schema5) + assertDSTableNullType(schema5) + } + + private def assertHiveTableNullType(schema: StructType): Unit = { + withTable("t") { + val e = intercept[AnalysisException] { + spark.catalog.createTable( + tableName = "t", + source = "hive", + schema = schema, + options = Map("fileFormat" -> "parquet")) + }.getMessage + assert(e.contains("Cannot create tables with null type")) + } + } + + private def assertDSTableNullType(schema: StructType): Unit = { + withTable("t") { + val e = intercept[AnalysisException] { + spark.catalog.createTable( + tableName = "t", + source = "json", + schema = schema, + options = Map.empty[String, String]) + }.getMessage + assert(e.contains("Cannot create tables with null type")) + } + } + test("SPARK-21216: join with a streaming DataFrame") { import org.apache.spark.sql.execution.streaming.MemoryStream import testImplicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index bb4ce6d3aa3f1..192fff2b98879 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -25,18 +25,18 @@ import org.apache.spark.sql.catalyst.util._ * A framework for running the query tests that are listed as a set of text files. * * TestSuites that derive from this class must provide a map of testCaseName to testCaseFiles - * that should be included. Additionally, there is support for whitelisting and blacklisting + * that should be included. Additionally, there is support for including and excluding * tests as development progresses. */ abstract class HiveQueryFileTest extends HiveComparisonTest { /** A list of tests deemed out of scope and thus completely disregarded */ - def blackList: Seq[String] = Nil + def excludeList: Seq[String] = Nil /** - * The set of tests that are believed to be working in catalyst. Tests not in whiteList - * blacklist are implicitly marked as ignored. + * The set of tests that are believed to be working in catalyst. Tests not in includeList or + * excludeList are implicitly marked as ignored. 
*/ - def whiteList: Seq[String] = ".*" :: Nil + def includeList: Seq[String] = ".*" :: Nil def testCases: Seq[(String, File)] @@ -45,25 +45,34 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { runOnlyDirectories.nonEmpty || skipDirectories.nonEmpty - val whiteListProperty: String = "spark.hive.whitelist" - // Allow the whiteList to be overridden by a system property - val realWhiteList: Seq[String] = - Option(System.getProperty(whiteListProperty)).map(_.split(",").toSeq).getOrElse(whiteList) + val deprecatedIncludeListProperty: String = "spark.hive.whitelist" + val includeListProperty: String = "spark.hive.includelist" + if (System.getProperty(deprecatedIncludeListProperty) != null) { + logWarning(s"System property `$deprecatedIncludeListProperty` is deprecated; please update " + + s"to use new property: $includeListProperty") + } + // Allow the includeList to be overridden by a system property + val realIncludeList: Seq[String] = + Option(System.getProperty(includeListProperty)) + .orElse(Option(System.getProperty(deprecatedIncludeListProperty))) + .map(_.split(",").toSeq) + .getOrElse(includeList) // Go through all the test cases and add them to scala test. testCases.sorted.foreach { case (testCaseName, testCaseFile) => - if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { - logDebug(s"Blacklisted test skipped $testCaseName") - } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || + if (excludeList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { + logDebug(s"Excluded test skipped $testCaseName") + } else if ( + realIncludeList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) createQueryTest(testCaseName, queriesString, reset = true, tryWithoutResettingFirst = true) } else { - // Only output warnings for the built in whitelist as this clutters the output when the user - // trying to execute a single test from the commandline. - if (System.getProperty(whiteListProperty) == null && !runAll) { + // Only output warnings for the built in includeList as this clutters the output when the + // user is trying to execute a single test from the commandline. + if (System.getProperty(includeListProperty) == null && !runAll) { ignore(testCaseName) {} } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index e5628c33b5ec8..cea7c5686054a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.hive.test.{HiveTestJars, TestHive} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.tags.SlowHiveTest case class TestData(a: Int, b: String) @@ -46,6 +47,7 @@ case class TestData(a: Int, b: String) * A set of test cases expressed in Hive QL that are not covered by the tests * included in the hive distribution. 
*/ +@SlowHiveTest class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala similarity index 94% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index b97eb869a9e54..35252fc47f49f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with TestHiveSingleton +class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { import spark.implicits._ @@ -83,7 +83,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with Tes val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") checkAnswer( rowsDf, - (child: SparkPlan) => new ScriptTransformationExec( + (child: SparkPlan) => new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -100,7 +100,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with Tes val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") checkAnswer( rowsDf, - (child: SparkPlan) => new ScriptTransformationExec( + (child: SparkPlan) => new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -118,7 +118,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with Tes val e = intercept[TestFailedException] { checkAnswer( rowsDf, - (child: SparkPlan) => new ScriptTransformationExec( + (child: SparkPlan) => new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -139,7 +139,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with Tes val e = intercept[TestFailedException] { checkAnswer( rowsDf, - (child: SparkPlan) => new ScriptTransformationExec( + (child: SparkPlan) => new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -160,7 +160,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with Tes val e = intercept[SparkException] { val plan = - new ScriptTransformationExec( + new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), @@ -181,7 +181,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with Tes checkAnswer( rowsDf, - (child: SparkPlan) => new ScriptTransformationExec( + (child: SparkPlan) => new HiveScriptTransformationExec( input = Seq(rowsDf.col("name").expr), script = "cat", output = Seq(AttributeReference("name", StringType)()), @@ -234,7 +234,7 @@ class ScriptTransformationSuite extends 
SparkPlanTest with SQLTestUtils with Tes val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") val e = intercept[SparkException] { val plan = - new ScriptTransformationExec( + new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), @@ -252,7 +252,7 @@ class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with Tes val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") val e = intercept[SparkException] { val plan = - new ScriptTransformationExec( + new HiveScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala index b20ef035594da..6f37e39a532d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -27,13 +27,14 @@ import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.types._ -import org.apache.spark.tags.ExtendedHiveTest +import org.apache.spark.tags.{ExtendedHiveTest, SlowHiveTest} import org.apache.spark.util.Utils /** * A separate set of DDL tests that uses Hive 2.1 libraries, which behave a little differently * from the built-in ones. */ +@SlowHiveTest @ExtendedHiveTest class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with BeforeAndAfterEach with BeforeAndAfterAll { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2fe6a59a27c1b..920f6385f8e19 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.tags.SlowHiveTest import org.apache.spark.util.Utils case class Nested1(f1: Nested2) @@ -2559,6 +2560,8 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi } } +@SlowHiveTest class SQLQuerySuite extends SQLQuerySuiteBase with DisableAdaptiveExecutionSuite +@SlowHiveTest class SQLQuerySuiteAE extends SQLQuerySuiteBase with EnableAdaptiveExecutionSuite diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala index e6856a58b0ea9..1f1a5568b0201 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala @@ -119,6 +119,27 @@ object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] { def outputEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]() } +object ArrayDataAgg extends Aggregator[Array[Double], Array[Double], Array[Double]] { + def zero: Array[Double] = Array(0.0, 0.0, 0.0) + def reduce(s: Array[Double], 
array: Array[Double]): Array[Double] = { + require(s.length == array.length) + for ( j <- 0 until s.length ) { + s(j) += array(j) + } + s + } + def merge(s1: Array[Double], s2: Array[Double]): Array[Double] = { + require(s1.length == s2.length) + for ( j <- 0 until s1.length ) { + s1(j) += s2(j) + } + s1 + } + def finish(s: Array[Double]): Array[Double] = s + def bufferEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]] + def outputEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]] +} + abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -156,20 +177,11 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi (3, null, null)).toDF("key", "value1", "value2") data2.write.saveAsTable("agg2") - val data3 = Seq[(Seq[Integer], Integer, Integer)]( - (Seq[Integer](1, 1), 10, -10), - (Seq[Integer](null), -60, 60), - (Seq[Integer](1, 1), 30, -30), - (Seq[Integer](1), 30, 30), - (Seq[Integer](2), 1, 1), - (null, -10, 10), - (Seq[Integer](2, 3), -1, null), - (Seq[Integer](2, 3), 1, 1), - (Seq[Integer](2, 3, 4), null, 1), - (Seq[Integer](null), 100, -10), - (Seq[Integer](3), null, 3), - (null, null, null), - (Seq[Integer](3), null, null)).toDF("key", "value1", "value2") + val data3 = Seq[(Seq[Double], Int)]( + (Seq(1.0, 2.0, 3.0), 0), + (Seq(4.0, 5.0, 6.0), 0), + (Seq(7.0, 8.0, 9.0), 0) + ).toDF("data", "dummy") data3.write.saveAsTable("agg3") val data4 = Seq[Boolean](true, false, true).toDF("boolvalues") @@ -184,6 +196,7 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi spark.udf.register("mydoublesum", udaf(MyDoubleSumAgg)) spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg)) spark.udf.register("longProductSum", udaf(LongProductSumAgg)) + spark.udf.register("arraysum", udaf(ArrayDataAgg)) } override def afterAll(): Unit = { @@ -354,6 +367,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil) } + test("SPARK-32159: array encoders should be resolved in analyzer") { + checkAnswer( + spark.sql("SELECT arraysum(data) FROM agg3"), + Row(Seq(12.0, 15.0, 18.0)) :: Nil) + } + test("verify aggregator ser/de behavior") { val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1") val agg = udaf(CountSerDeAgg) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/security/HiveHadoopDelegationTokenManagerSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/security/HiveHadoopDelegationTokenManagerSuite.scala index 97eab4f3f4f77..f8f555197daef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/security/HiveHadoopDelegationTokenManagerSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/security/HiveHadoopDelegationTokenManagerSuite.scala @@ -52,8 +52,8 @@ class HiveHadoopDelegationTokenManagerSuite extends SparkFunSuite { throw new ClassNotFoundException(name) } - val prefixBlacklist = Seq("java", "scala", "com.sun.", "sun.") - if (prefixBlacklist.exists(name.startsWith(_))) { + val prefixExcludeList = Seq("java", "scala", "com.sun.", "sun.") + if (prefixExcludeList.exists(name.startsWith(_))) { return currentLoader.loadClass(name) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 683db21d3f0e1..37cc1b8a6d2ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -123,7 +123,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { val jobOption = outputStream.generateJob(time) jobOption.foreach(_.setCallSite(outputStream.creationSite)) jobOption - } + }.toSeq } logDebug("Generated " + jobs.length + " jobs for time " + time) jobs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 2d53a1b4c78b6..af3f5a060f54b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -67,7 +67,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param master Name of the Spark Master * @param appName Name to be used when registering with the scheduler * @param batchDuration The time interval at which streaming data will be divided into batches - * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param sparkHome The SPARK_HOME directory on the worker nodes * @param jarFile JAR file containing job code, to ship to cluster. This can be a path on the * local file system or an HDFS, HTTP, HTTPS, or FTP URL. */ @@ -84,7 +84,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param master Name of the Spark Master * @param appName Name to be used when registering with the scheduler * @param batchDuration The time interval at which streaming data will be divided into batches - * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param sparkHome The SPARK_HOME directory on the worker nodes * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. */ @@ -101,7 +101,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param master Name of the Spark Master * @param appName Name to be used when registering with the scheduler * @param batchDuration The time interval at which streaming data will be divided into batches - * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param sparkHome The SPARK_HOME directory on the worker nodes * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. 
* @param environment Environment variables to set on worker nodes @@ -366,7 +366,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) + sQueue ++= queue.asScala.map(_.rdd) ssc.queueStream(sQueue) } @@ -390,7 +390,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) + sQueue ++= queue.asScala.map(_.rdd) ssc.queueStream(sQueue, oneAtATime) } @@ -415,7 +415,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) + sQueue ++= queue.asScala.map(_.rdd) ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala index ee8370d262609..7555e2f57fccb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala @@ -65,7 +65,7 @@ private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: Jav private def toJavaBatchInfo(batchInfo: BatchInfo): JavaBatchInfo = { JavaBatchInfo( batchInfo.batchTime, - batchInfo.streamIdToInputInfo.mapValues(toJavaStreamInputInfo(_)).asJava, + batchInfo.streamIdToInputInfo.mapValues(toJavaStreamInputInfo).toMap.asJava, batchInfo.submissionTime, batchInfo.processingStartTime.getOrElse(-1), batchInfo.processingEndTime.getOrElse(-1), @@ -73,7 +73,7 @@ private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: Jav batchInfo.processingDelay.getOrElse(-1), batchInfo.totalDelay.getOrElse(-1), batchInfo.numRecords, - batchInfo.outputOperationInfos.mapValues(toJavaOutputOperationInfo(_)).asJava + batchInfo.outputOperationInfos.mapValues(toJavaOutputOperationInfo).toMap.asJava ) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index d46c0a01e05d9..2f4536ec6f0c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -45,7 +45,7 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) s" time $validTime") } if (rdds.nonEmpty) { - Some(ssc.sc.union(rdds)) + Some(ssc.sc.union(rdds.toSeq)) } else { None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 8da5a5f8193cf..662312b7b0db8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -75,7 +75,7 @@ private[streaming] object MapWithStateRDDRecord { } } - MapWithStateRDDRecord(newStateMap, mappedData) + 
MapWithStateRDDRecord(newStateMap, mappedData.toSeq) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 6c71b18b46213..d038021e93e73 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -116,7 +116,9 @@ private[streaming] class ReceivedBlockTracker( // a few thousand elements. So we explicitly allocate a collection for serialization which // we know doesn't have this issue. (See SPARK-26734). val streamIdToBlocks = streamIds.map { streamId => - (streamId, mutable.ArrayBuffer(getReceivedBlockQueue(streamId).clone(): _*)) + val blocks = mutable.ArrayBuffer[ReceivedBlockInfo]() + blocks ++= getReceivedBlockQueue(streamId).clone() + (streamId, blocks.toSeq) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index 4105171a3db24..0569abab1f36d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -135,7 +135,7 @@ private[streaming] class ReceiverSchedulingPolicy { leastScheduledExecutors += executor } - receivers.map(_.streamId).zip(scheduledLocations).toMap + receivers.map(_.streamId).zip(scheduledLocations.map(_.toSeq)).toMap } /** @@ -183,7 +183,7 @@ private[streaming] class ReceiverSchedulingPolicy { val executorWeights: Map[ExecutorCacheTaskLocation, Double] = { receiverTrackingInfoMap.values.flatMap(convertReceiverTrackingInfoToExecutorWeights) - .groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + .groupBy(_._1).mapValues(_.map(_._2).sum).toMap // Sum weights for each executor } val idleExecutors = executors.toSet -- executorWeights.keys diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 13cf5cc0e71ea..342a0a43b5068 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -248,7 +248,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false _.runningExecutor.map { _.executorId } - } + }.toMap } else { Map.empty } @@ -415,7 +415,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** - * Run the dummy Spark job to ensure that all slaves have registered. This avoids all the + * Run the dummy Spark job to ensure that all executors have registered. This avoids all the * receivers to be scheduled on the same node. 
* * TODO Should poll the executor number and wait for executors according to diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 31e4c6b59a64a..d0a3517af70b9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -170,7 +170,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp // We take the latest record for the timestamp. Please refer to the class Javadoc for // detailed explanation val time = sortedByTime.last.time - segment = wrappedLog.write(aggregate(sortedByTime), time) + segment = wrappedLog.write(aggregate(sortedByTime.toSeq), time) } buffer.foreach(_.promise.success(segment)) } catch { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index d33f83c819086..2e5000159bcb7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -146,7 +146,7 @@ private[streaming] class FileBasedWriteAheadLog( } else { // For performance gains, it makes sense to parallelize the recovery if // closeFileAfterWrite = true - seqToParIterator(executionContext, logFilesToRead, readFile).asJava + seqToParIterator(executionContext, logFilesToRead.toSeq, readFile).asJava } } @@ -277,10 +277,10 @@ private[streaming] object FileBasedWriteAheadLog { } def getCallerName(): Option[String] = { - val blacklist = Seq("WriteAheadLog", "Logging", "java.lang", "scala.") + val ignoreList = Seq("WriteAheadLog", "Logging", "java.lang", "scala.") Thread.currentThread.getStackTrace() .map(_.getClassName) - .find { c => !blacklist.exists(c.contains) } + .find { c => !ignoreList.exists(c.contains) } .flatMap(_.split("\\.").lastOption) .flatMap(_.split("\\$\\$").headOption) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 9cdfdb8374322..e207dab7de068 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -93,7 +93,7 @@ object RawTextHelper { } /** - * Warms up the SparkContext in master and slave by running tasks to force JIT kick in + * Warms up the SparkContext in master and executor by running tasks to force JIT kick in * before real workload starts. 
*/ def warmUp(sc: SparkContext): Unit = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index 36036fcd44b04..541a6e2d48b51 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -190,7 +190,7 @@ class DStreamScopeSuite assertDefined(foreachBaseScope) assert(foreachBaseScope.get.name === "foreachRDD") - val rddScopes = generatedRDDs.map { _.scope } + val rddScopes = generatedRDDs.map { _.scope }.toSeq assertDefined(rddScopes: _*) rddScopes.zipWithIndex.foreach { case (rddScope, idx) => assert(rddScope.get.name === "reduceByKey") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala index 0c4a64ccc513f..42a5aaba5178f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/JavaTestUtils.scala @@ -36,7 +36,7 @@ trait JavaTestBase extends TestSuiteBase { ssc: JavaStreamingContext, data: JList[JList[T]], numPartitions: Int): JavaDStream[T] = { - val seqData = data.asScala.map(_.asScala) + val seqData = data.asScala.map(_.asScala.toSeq).toSeq implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index d0a5ababc7cac..9d735a32f7090 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -284,7 +284,7 @@ object MasterFailureTest extends Logging { }) } } - mergedOutput + mergedOutput.toSeq } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index bb60d6fa7bf78..60e04403937a2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -612,7 +612,7 @@ object WriteAheadLogSuite { } } writer.close() - segments + segments.toSeq } /** @@ -685,7 +685,7 @@ object WriteAheadLogSuite { } finally { reader.close() } - buffer + buffer.toSeq } /** Read all the data from a log file using reader class and return the list of byte buffers. */
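The script transformation refactoring earlier in this patch (HiveScriptTransformationExec and HiveScriptTransformationWriterThread) is built around one pattern: launch the user's script as a child process, feed input rows to its stdin from a dedicated daemon thread, read its stdout on the calling thread, keep stderr in a bounded circular buffer, and check the exit status once the output is exhausted. The following is a minimal, self-contained sketch of that pattern only, not the Spark implementation; the `runScript` helper, its 2048-character cap, and the thread names are illustrative assumptions.

```scala
import java.io.{BufferedReader, InputStreamReader, PrintWriter}
import java.nio.charset.StandardCharsets

// Sketch of the "feed rows to an external script and read its output" pattern:
// a daemon writer thread streams input lines to the child's stdin while the caller
// consumes stdout; stderr is kept in a small bounded buffer for error reporting.
object ScriptTransformSketch {
  def runScript(command: Seq[String], input: Iterator[String]): Iterator[String] = {
    val proc = new ProcessBuilder(command: _*).start()

    // Bounded stderr capture (stands in for Spark's CircularBuffer).
    val stderr = new StringBuilder
    val stderrThread = new Thread(() => {
      val r = new BufferedReader(
        new InputStreamReader(proc.getErrorStream, StandardCharsets.UTF_8))
      Iterator.continually(r.readLine()).takeWhile(_ != null).foreach { line =>
        if (stderr.length < 2048) stderr.append(line).append('\n')
      }
    }, "stderr-consumer")
    stderrThread.setDaemon(true)
    stderrThread.start()

    // Writer thread: pushes the input rows into the child's stdin, then closes it.
    val writerThread = new Thread(() => {
      val w = new PrintWriter(proc.getOutputStream)
      try input.foreach(line => w.println(line)) finally w.close()
    }, "stdin-feeder")
    writerThread.setDaemon(true)
    writerThread.start()

    // The calling thread reads the child's stdout line by line.
    val out = new BufferedReader(
      new InputStreamReader(proc.getInputStream, StandardCharsets.UTF_8))
    new Iterator[String] {
      private var line = out.readLine()
      def hasNext: Boolean = {
        if (line == null) {
          // On EOF, surface a non-zero exit code together with the captured stderr.
          val code = proc.waitFor()
          if (code != 0) throw new RuntimeException(s"script exited with $code: $stderr")
        }
        line != null
      }
      def next(): String = { val cur = line; line = out.readLine(); cur }
    }
  }

  def main(args: Array[String]): Unit = {
    // Example: pipe three rows through `cat`, as the transformation tests do.
    runScript(Seq("cat"), Iterator("a", "b", "c")).foreach(println)
  }
}
```

The real implementation additionally handles SerDe-based records, row-format options, task-context propagation, and forwarding exceptions from the writer thread, which the patch consolidates in the shared BaseScriptTransformationExec / BaseScriptTransformationWriterThread classes.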
 <tr><th>Writable Type</th><th>Python Type</th></tr>
-<tr><td>Text</td><td>unicode str</td></tr>
+<tr><td>Text</td><td>str</td></tr>
 <tr><td>IntWritable</td><td>int</td></tr>
 <tr><td>FloatWritable</td><td>float</td></tr>
 <tr><td>DoubleWritable</td><td>float</td></tr>
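The table above lists PySpark's default Writable-to-Python conversions for SequenceFiles, with `Text` now mapping to `str` after the Python 2 cleanup. As a small illustration of where that mapping applies, the sketch below writes a SequenceFile from the JVM side; reading the same path from PySpark with `sc.sequenceFile(...)` would then yield `(str, int)` pairs per the table. The output path and object name are hypothetical.

```scala
import org.apache.spark.sql.SparkSession

// A Writable produced on the JVM side maps to the Python types in the table above
// when read back through PySpark's sc.sequenceFile.
object SequenceFileSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("seqfile-sketch")
      .getOrCreate()
    val sc = spark.sparkContext

    // (String, Int) pairs are written as (Text, IntWritable) ...
    sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
      .saveAsSequenceFile("/tmp/seqfile-sketch")

    // ... and read back on the JVM as (Text, IntWritable); in PySpark,
    // sc.sequenceFile("/tmp/seqfile-sketch") would yield (str, int) pairs.
    import org.apache.hadoop.io.{IntWritable, Text}
    sc.sequenceFile("/tmp/seqfile-sketch", classOf[Text], classOf[IntWritable])
      .map { case (k, v) => (k.toString, v.get) }
      .collect()
      .foreach(println)

    spark.stop()
  }
}
```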
-      <td>Mesos Slave ID</td><td>{state.slaveId.getValue}</td>
+      <td>Mesos Agent ID</td><td>{state.agentId.getValue}</td>
       <td>Mesos Task ID</td>
       <td>{state.driverDescription.command.mainClass}</td>
       <td>cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem}</td>
       <td>{UIUtils.formatDate(state.startDate)}</td>
-      <td>{state.slaveId.getValue}</td>
+      <td>{state.agentId.getValue}</td>
       <td>{stateString(state.mesosTaskStatus)}</td>
       <td>{sandboxCol}</td>
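For context on how the renamed HiveScriptTransformationSuite exercises these code paths from SQL: the `TRANSFORM ... USING` clause routes each input row through an external command (the tests use `cat`) and parses the command's stdout back into output rows. Below is a minimal usage sketch, assuming a local Hive-enabled session and a temporary view name of my own choosing; it is not taken from the patch itself.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical end-to-end usage of the script transformation path: TRANSFORM pipes
// each input row through an external command ("cat" here) and reads the command's
// stdout back as the query output.
object TransformUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("transform-sketch")
      .enableHiveSupport() // needs spark-hive on the classpath
      .getOrCreate()
    import spark.implicits._

    Seq("a", "b", "c").toDF("a").createOrReplaceTempView("t")

    // Each value of column `a` is written to the script's stdin; its stdout becomes column `a`.
    spark.sql("SELECT TRANSFORM(a) USING 'cat' AS (a) FROM t").show()

    spark.stop()
  }
}
```

When no SERDE or ROW FORMAT clause is given, the default tab-delimited field and newline-delimited line format from the schema's defaultFormat map applies, which is why `cat` round-trips the rows unchanged in the tests.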